diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-09 23:56:13 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-09 23:56:13 +0200 |
commit | 730923f30c0e74fbbcdcf557763c549f554e8bdb (patch) | |
tree | ed8b1d48338872c9f0d3e7c659ed4aa5ccc621fc /rag | |
parent | 040aff556698ed96fa8b62f10033a0d9e8e9d0f5 (diff) |
Fix ui
Diffstat (limited to 'rag')
-rw-r--r-- | rag/ui.py | 38 |
1 files changed, 28 insertions, 10 deletions
@@ -1,18 +1,42 @@ +from typing import Type + import streamlit as st from dotenv import load_dotenv from langchain_community.document_loaders.blob_loaders import Blob from rag.generator import MODELS, get_generator +from rag.generator.abstract import AbstractGenerator from rag.generator.prompt import Prompt from rag.retriever.retriever import Retriever + +@st.cache_resource +def load_retriever() -> Retriever: + return Retriever() + + +@st.cache_resource +def load_generator(model: str) -> Type[AbstractGenerator]: + return get_generator(model) + + +@st.cache_data(show_spinner=False) +def upload(files): + with st.spinner("Indexing documents..."): + for file in files: + source = file.name + blob = Blob.from_data(file.read()) + retriever.add_pdf(blob=blob, source=source) + + if __name__ == "__main__": load_dotenv() - retriever = Retriever() + retriever = load_retriever() ss = st.session_state - st.header("Retrieval Augmented Generation") + st.title("Retrieval Augmented Generation") model = st.selectbox("Generative Model", options=MODELS) + generator = load_generator(model) files = st.file_uploader( "Choose pdfs to add to the knowledge base", @@ -20,12 +44,7 @@ if __name__ == "__main__": accept_multiple_files=True, ) - if files: - with st.spinner("Indexing documents..."): - for file in files: - source = file.name - blob = Blob.from_data(file.read()) - retriever.add_pdf(blob=blob, source=source) + upload(files) with st.form(key="query"): query = st.text_area( @@ -41,7 +60,7 @@ if __name__ == "__main__": (result_column, context_column) = st.columns(2) - if submit and model: + if submit: if not query: st.stop() @@ -60,6 +79,5 @@ if __name__ == "__main__": st.markdown("---") with result_column: - generator = get_generator(model) st.markdown("### Answer") st.write_stream(generator.generate(prompt)) |