summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--rag/ui.py38
1 files changed, 28 insertions, 10 deletions
diff --git a/rag/ui.py b/rag/ui.py
index 211c4b8..5a5e01a 100644
--- a/rag/ui.py
+++ b/rag/ui.py
@@ -1,18 +1,42 @@
+from typing import Type
+
import streamlit as st
from dotenv import load_dotenv
from langchain_community.document_loaders.blob_loaders import Blob
from rag.generator import MODELS, get_generator
+from rag.generator.abstract import AbstractGenerator
from rag.generator.prompt import Prompt
from rag.retriever.retriever import Retriever
+
+@st.cache_resource
+def load_retriever() -> Retriever:
+ return Retriever()
+
+
+@st.cache_resource
+def load_generator(model: str) -> Type[AbstractGenerator]:
+ return get_generator(model)
+
+
+@st.cache_data(show_spinner=False)
+def upload(files):
+ with st.spinner("Indexing documents..."):
+ for file in files:
+ source = file.name
+ blob = Blob.from_data(file.read())
+ retriever.add_pdf(blob=blob, source=source)
+
+
if __name__ == "__main__":
load_dotenv()
- retriever = Retriever()
+ retriever = load_retriever()
ss = st.session_state
- st.header("Retrieval Augmented Generation")
+ st.title("Retrieval Augmented Generation")
model = st.selectbox("Generative Model", options=MODELS)
+ generator = load_generator(model)
files = st.file_uploader(
"Choose pdfs to add to the knowledge base",
@@ -20,12 +44,7 @@ if __name__ == "__main__":
accept_multiple_files=True,
)
- if files:
- with st.spinner("Indexing documents..."):
- for file in files:
- source = file.name
- blob = Blob.from_data(file.read())
- retriever.add_pdf(blob=blob, source=source)
+ upload(files)
with st.form(key="query"):
query = st.text_area(
@@ -41,7 +60,7 @@ if __name__ == "__main__":
(result_column, context_column) = st.columns(2)
- if submit and model:
+ if submit:
if not query:
st.stop()
@@ -60,6 +79,5 @@ if __name__ == "__main__":
st.markdown("---")
with result_column:
- generator = get_generator(model)
st.markdown("### Answer")
st.write_stream(generator.generate(prompt))