summaryrefslogtreecommitdiff
path: root/rag/ui.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/ui.py')
-rw-r--r--rag/ui.py45
1 files changed, 19 insertions, 26 deletions
diff --git a/rag/ui.py b/rag/ui.py
index 83a22e2..5083bb4 100644
--- a/rag/ui.py
+++ b/rag/ui.py
@@ -1,40 +1,33 @@
import streamlit as st
from langchain_community.document_loaders.blob_loaders import Blob
-from .rag import RAG
-from .generator import get_generator
-from .generator.prompt import Prompt
+from dotenv import load_dotenv
+from generator import get_generator, MODELS
+from generator.prompt import Prompt
+from retriever.retriever import Retriever
-rag = RAG()
-MODELS = ["ollama", "cohere"]
+if __name__ == "__main__":
+ load_dotenv()
+ retriever = Retriever()
+ ss = st.session_state
+ st.header("Retrieval Augmented Generation")
+ model = st.selectbox("Model", options=MODELS)
-def upload_pdfs():
files = st.file_uploader(
"Choose pdfs to add to the knowledge base",
type="pdf",
accept_multiple_files=True,
)
- if not files:
- return
-
- with st.spinner("Indexing documents..."):
- for file in files:
- source = file.name
- blob = Blob.from_data(file.read())
- rag.add_pdf(blob, source)
-
-
-if __name__ == "__main__":
- ss = st.session_state
- st.header("RAG-UI")
-
- model = st.selectbox("Model", options=MODELS)
-
- upload_pdfs()
+ if files:
+ with st.spinner("Indexing documents..."):
+ for file in files:
+ source = file.name
+ blob = Blob.from_data(file.read())
+ retriever.add_pdf(blob=blob, source=source)
with st.form(key="query"):
query = st.text_area(
@@ -51,13 +44,13 @@ if __name__ == "__main__":
(b,) = st.columns(1)
(result_column, context_column) = st.columns(2)
- if submit:
+ if submit and model:
if not query:
st.stop()
query = ss.get("query", "")
with st.spinner("Searching for documents..."):
- documents = rag.retrieve(query)
+ documents = retriever.retrieve(query)
prompt = Prompt(query, documents)
@@ -72,4 +65,4 @@ if __name__ == "__main__":
with result_column:
generator = get_generator(model)
st.markdown("### Answer")
- st.write_stream(rag.generate(generator, prompt))
+ st.write_stream(generator.generate(generator, prompt))