From 91ddb3672e514fa9824609ff047d7cab0c65631a Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 9 Apr 2024 00:14:00 +0200 Subject: Refactor --- rag/ui.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) (limited to 'rag/ui.py') diff --git a/rag/ui.py b/rag/ui.py index 84dbbeb..83a22e2 100644 --- a/rag/ui.py +++ b/rag/ui.py @@ -1,16 +1,15 @@ import streamlit as st from langchain_community.document_loaders.blob_loaders import Blob +from .rag import RAG -try: - from rag.rag import RAG - from rag.llm.ollama_generator import Prompt -except ModuleNotFoundError: - from rag import RAG - from llm.ollama_generator import Prompt +from .generator import get_generator +from .generator.prompt import Prompt rag = RAG() +MODELS = ["ollama", "cohere"] + def upload_pdfs(): files = st.file_uploader( @@ -26,13 +25,15 @@ def upload_pdfs(): for file in files: source = file.name blob = Blob.from_data(file.read()) - rag.add_pdf_from_blob(blob, source) + rag.add_pdf(blob, source) if __name__ == "__main__": ss = st.session_state st.header("RAG-UI") + model = st.selectbox("Model", options=MODELS) + upload_pdfs() with st.form(key="query"): @@ -56,7 +57,7 @@ if __name__ == "__main__": query = ss.get("query", "") with st.spinner("Searching for documents..."): - documents = rag.search(query) + documents = rag.retrieve(query) prompt = Prompt(query, documents) @@ -69,6 +70,6 @@ if __name__ == "__main__": st.markdown("---") with result_column: + generator = get_generator(model) st.markdown("### Answer") - st.write_stream(rag.retrieve(prompt)) - + st.write_stream(rag.generate(generator, prompt)) -- cgit v1.2.3-70-g09d2