From b1ff0c55422d7b0af2c379679b8721014ef36926 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 18 Jun 2024 01:37:32 +0200 Subject: Wip rewrite --- rag/ui.py | 23 +++++------------------ 1 file changed, 5 insertions(+), 18 deletions(-) (limited to 'rag/ui.py') diff --git a/rag/ui.py b/rag/ui.py index ddb3d78..a453f47 100644 --- a/rag/ui.py +++ b/rag/ui.py @@ -20,15 +20,9 @@ def set_chat_users(): ss.bot = "assistant" -@st.cache_resource -def load_retriever(): - log.debug("Loading retriever model") - st.session_state.retriever = Retriever() - - def load_generator(model: str): - log.debug("Loading generator model") - st.session_state.generator = get_generator(model) + log.debug("Loading rag") + st.session_state.rag = get_generator(model) def load_reranker(model: str): @@ -70,17 +64,10 @@ def generate_chat(query: str): with st.chat_message(ss.user): st.write(query) - retriever = ss.retriever - generator = ss.generator - reranker = ss.reranker - - documents = retriever.retrieve(query) - prompt = Prompt(query, documents) - - prompt = reranker.rank(prompt) - + rag = ss.rag + prompt = rag.retrieve(query) with st.chat_message(ss.bot): - response = st.write_stream(generator.generate(prompt)) + response = st.write_stream(rag.generate(query)) display_context(prompt.documents) store_chat(query, response, prompt.documents) -- cgit v1.2.3-70-g09d2