summaryrefslogtreecommitdiff
path: root/rag/ui.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/ui.py')
-rw-r--r--rag/ui.py23
1 files changed, 5 insertions, 18 deletions
diff --git a/rag/ui.py b/rag/ui.py
index ddb3d78..a453f47 100644
--- a/rag/ui.py
+++ b/rag/ui.py
@@ -20,15 +20,9 @@ def set_chat_users():
ss.bot = "assistant"
-@st.cache_resource
-def load_retriever():
- log.debug("Loading retriever model")
- st.session_state.retriever = Retriever()
-
-
def load_generator(model: str):
- log.debug("Loading generator model")
- st.session_state.generator = get_generator(model)
+ log.debug("Loading rag")
+ st.session_state.rag = get_generator(model)
def load_reranker(model: str):
@@ -70,17 +64,10 @@ def generate_chat(query: str):
with st.chat_message(ss.user):
st.write(query)
- retriever = ss.retriever
- generator = ss.generator
- reranker = ss.reranker
-
- documents = retriever.retrieve(query)
- prompt = Prompt(query, documents)
-
- prompt = reranker.rank(prompt)
-
+ rag = ss.rag
+ prompt = rag.retrieve(query)
with st.chat_message(ss.bot):
- response = st.write_stream(generator.generate(prompt))
+ response = st.write_stream(rag.generate(query))
display_context(prompt.documents)
store_chat(query, response, prompt.documents)