diff options
Diffstat (limited to 'rag/ui.py')
-rw-r--r-- | rag/ui.py | 21 |
1 files changed, 11 insertions, 10 deletions
@@ -1,16 +1,15 @@ import streamlit as st from langchain_community.document_loaders.blob_loaders import Blob +from .rag import RAG -try: - from rag.rag import RAG - from rag.llm.ollama_generator import Prompt -except ModuleNotFoundError: - from rag import RAG - from llm.ollama_generator import Prompt +from .generator import get_generator +from .generator.prompt import Prompt rag = RAG() +MODELS = ["ollama", "cohere"] + def upload_pdfs(): files = st.file_uploader( @@ -26,13 +25,15 @@ def upload_pdfs(): for file in files: source = file.name blob = Blob.from_data(file.read()) - rag.add_pdf_from_blob(blob, source) + rag.add_pdf(blob, source) if __name__ == "__main__": ss = st.session_state st.header("RAG-UI") + model = st.selectbox("Model", options=MODELS) + upload_pdfs() with st.form(key="query"): @@ -56,7 +57,7 @@ if __name__ == "__main__": query = ss.get("query", "") with st.spinner("Searching for documents..."): - documents = rag.search(query) + documents = rag.retrieve(query) prompt = Prompt(query, documents) @@ -69,6 +70,6 @@ if __name__ == "__main__": st.markdown("---") with result_column: + generator = get_generator(model) st.markdown("### Answer") - st.write_stream(rag.retrieve(prompt)) - + st.write_stream(rag.generate(generator, prompt)) |