diff options
Diffstat (limited to 'rag/ui.py')
-rw-r--r-- | rag/ui.py | 45 |
1 files changed, 19 insertions, 26 deletions
@@ -1,40 +1,33 @@ import streamlit as st from langchain_community.document_loaders.blob_loaders import Blob -from .rag import RAG -from .generator import get_generator -from .generator.prompt import Prompt +from dotenv import load_dotenv +from generator import get_generator, MODELS +from generator.prompt import Prompt +from retriever.retriever import Retriever -rag = RAG() -MODELS = ["ollama", "cohere"] +if __name__ == "__main__": + load_dotenv() + retriever = Retriever() + ss = st.session_state + st.header("Retrieval Augmented Generation") + model = st.selectbox("Model", options=MODELS) -def upload_pdfs(): files = st.file_uploader( "Choose pdfs to add to the knowledge base", type="pdf", accept_multiple_files=True, ) - if not files: - return - - with st.spinner("Indexing documents..."): - for file in files: - source = file.name - blob = Blob.from_data(file.read()) - rag.add_pdf(blob, source) - - -if __name__ == "__main__": - ss = st.session_state - st.header("RAG-UI") - - model = st.selectbox("Model", options=MODELS) - - upload_pdfs() + if files: + with st.spinner("Indexing documents..."): + for file in files: + source = file.name + blob = Blob.from_data(file.read()) + retriever.add_pdf(blob=blob, source=source) with st.form(key="query"): query = st.text_area( @@ -51,13 +44,13 @@ if __name__ == "__main__": (b,) = st.columns(1) (result_column, context_column) = st.columns(2) - if submit: + if submit and model: if not query: st.stop() query = ss.get("query", "") with st.spinner("Searching for documents..."): - documents = rag.retrieve(query) + documents = retriever.retrieve(query) prompt = Prompt(query, documents) @@ -72,4 +65,4 @@ if __name__ == "__main__": with result_column: generator = get_generator(model) st.markdown("### Answer") - st.write_stream(rag.generate(generator, prompt)) + st.write_stream(generator.generate(generator, prompt)) |