From d487ef8b04cc7f5ac1491f0638f902fe2abe5ac5 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 8 Apr 2024 22:28:47 +0200 Subject: Wip refactor --- rag/ui.py | 59 +++++++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 39 insertions(+), 20 deletions(-) (limited to 'rag/ui.py') diff --git a/rag/ui.py b/rag/ui.py index 37c50dd..84dbbeb 100644 --- a/rag/ui.py +++ b/rag/ui.py @@ -4,8 +4,10 @@ from langchain_community.document_loaders.blob_loaders import Blob try: from rag.rag import RAG + from rag.llm.ollama_generator import Prompt except ModuleNotFoundError: from rag import RAG + from llm.ollama_generator import Prompt rag = RAG() @@ -16,9 +18,15 @@ def upload_pdfs(): type="pdf", accept_multiple_files=True, ) - for file in files: - blob = Blob.from_data(file.read()) - rag.add_pdf_from_blob(blob) + + if not files: + return + + with st.spinner("Indexing documents..."): + for file in files: + source = file.name + blob = Blob.from_data(file.read()) + rag.add_pdf_from_blob(blob, source) if __name__ == "__main__": @@ -26,30 +34,41 @@ if __name__ == "__main__": st.header("RAG-UI") upload_pdfs() - query = st.text_area( - "query", - key="query", - height=100, - placeholder="Enter query here", - help="", - label_visibility="collapsed", - disabled=False, - ) + + with st.form(key="query"): + query = st.text_area( + "query", + key="query", + height=100, + placeholder="Enter query here", + help="", + label_visibility="collapsed", + disabled=False, + ) + submit = st.form_submit_button("Generate") (b,) = st.columns(1) (result_column, context_column) = st.columns(2) - if b.button("Generate", disabled=False, type="primary", use_container_width=True): + if submit: + if not query: + st.stop() + query = ss.get("query", "") - with st.spinner("Generating answer..."): - response = rag.retrieve(query) + with st.spinner("Searching for documents..."): + documents = rag.search(query) - with result_column: - st.markdown("### Answer") - st.markdown(response.answer) + prompt = Prompt(query, documents) with context_column: st.markdown("### Context") - for c in response.context: - st.markdown(c) + for i, doc in enumerate(documents): + st.markdown(f"### Document {i}") + st.markdown(f"**Title: {doc.title}**") + st.markdown(doc.text) st.markdown("---") + + with result_column: + st.markdown("### Answer") + st.write_stream(rag.retrieve(prompt)) + -- cgit v1.2.3-70-g09d2