summaryrefslogtreecommitdiff
path: root/rag/ui.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/ui.py')
-rw-r--r--rag/ui.py59
1 files changed, 39 insertions, 20 deletions
diff --git a/rag/ui.py b/rag/ui.py
index 37c50dd..84dbbeb 100644
--- a/rag/ui.py
+++ b/rag/ui.py
@@ -4,8 +4,10 @@ from langchain_community.document_loaders.blob_loaders import Blob
try:
from rag.rag import RAG
+ from rag.llm.ollama_generator import Prompt
except ModuleNotFoundError:
from rag import RAG
+ from llm.ollama_generator import Prompt
rag = RAG()
@@ -16,9 +18,15 @@ def upload_pdfs():
type="pdf",
accept_multiple_files=True,
)
- for file in files:
- blob = Blob.from_data(file.read())
- rag.add_pdf_from_blob(blob)
+
+ if not files:
+ return
+
+ with st.spinner("Indexing documents..."):
+ for file in files:
+ source = file.name
+ blob = Blob.from_data(file.read())
+ rag.add_pdf_from_blob(blob, source)
if __name__ == "__main__":
@@ -26,30 +34,41 @@ if __name__ == "__main__":
st.header("RAG-UI")
upload_pdfs()
- query = st.text_area(
- "query",
- key="query",
- height=100,
- placeholder="Enter query here",
- help="",
- label_visibility="collapsed",
- disabled=False,
- )
+
+ with st.form(key="query"):
+ query = st.text_area(
+ "query",
+ key="query",
+ height=100,
+ placeholder="Enter query here",
+ help="",
+ label_visibility="collapsed",
+ disabled=False,
+ )
+ submit = st.form_submit_button("Generate")
(b,) = st.columns(1)
(result_column, context_column) = st.columns(2)
- if b.button("Generate", disabled=False, type="primary", use_container_width=True):
+ if submit:
+ if not query:
+ st.stop()
+
query = ss.get("query", "")
- with st.spinner("Generating answer..."):
- response = rag.retrieve(query)
+ with st.spinner("Searching for documents..."):
+ documents = rag.search(query)
- with result_column:
- st.markdown("### Answer")
- st.markdown(response.answer)
+ prompt = Prompt(query, documents)
with context_column:
st.markdown("### Context")
- for c in response.context:
- st.markdown(c)
+ for i, doc in enumerate(documents):
+ st.markdown(f"### Document {i}")
+ st.markdown(f"**Title: {doc.title}**")
+ st.markdown(doc.text)
st.markdown("---")
+
+ with result_column:
+ st.markdown("### Answer")
+ st.write_stream(rag.retrieve(prompt))
+