From 730923f30c0e74fbbcdcf557763c549f554e8bdb Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Tue, 9 Apr 2024 23:56:13 +0200
Subject: Fix ui

---
 rag/ui.py | 38 ++++++++++++++++++++++++++++----------
 1 file changed, 28 insertions(+), 10 deletions(-)

diff --git a/rag/ui.py b/rag/ui.py
index 211c4b8..5a5e01a 100644
--- a/rag/ui.py
+++ b/rag/ui.py
@@ -1,18 +1,42 @@
+from typing import Type
+
 import streamlit as st
 from dotenv import load_dotenv
 from langchain_community.document_loaders.blob_loaders import Blob
 
 from rag.generator import MODELS, get_generator
+from rag.generator.abstract import AbstractGenerator
 from rag.generator.prompt import Prompt
 from rag.retriever.retriever import Retriever
 
+
+@st.cache_resource
+def load_retriever() -> Retriever:
+    return Retriever()
+
+
+@st.cache_resource
+def load_generator(model: str) -> Type[AbstractGenerator]:
+    return get_generator(model)
+
+
+@st.cache_data(show_spinner=False)
+def upload(files):
+    with st.spinner("Indexing documents..."):
+        for file in files:
+            source = file.name
+            blob = Blob.from_data(file.read())
+            retriever.add_pdf(blob=blob, source=source)
+
+
 if __name__ == "__main__":
     load_dotenv()
-    retriever = Retriever()
+    retriever = load_retriever()
 
     ss = st.session_state
 
-    st.header("Retrieval Augmented Generation")
+    st.title("Retrieval Augmented Generation")
 
     model = st.selectbox("Generative Model", options=MODELS)
+    generator = load_generator(model)
 
     files = st.file_uploader(
         "Choose pdfs to add to the knowledge base",
@@ -20,12 +44,7 @@ if __name__ == "__main__":
         accept_multiple_files=True,
     )
 
-    if files:
-        with st.spinner("Indexing documents..."):
-            for file in files:
-                source = file.name
-                blob = Blob.from_data(file.read())
-                retriever.add_pdf(blob=blob, source=source)
+    upload(files)
 
     with st.form(key="query"):
         query = st.text_area(
@@ -41,7 +60,7 @@ if __name__ == "__main__":
 
     (result_column, context_column) = st.columns(2)
 
-    if submit and model:
+    if submit:
         if not query:
             st.stop()
 
@@ -60,6 +79,5 @@ if __name__ == "__main__":
             st.markdown("---")
 
     with result_column:
-        generator = get_generator(model)
         st.markdown("### Answer")
         st.write_stream(generator.generate(prompt))
-- 
cgit v1.2.3-70-g09d2