summaryrefslogtreecommitdiff
path: root/rag/ui.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/ui.py')
-rw-r--r--rag/ui.py14
1 files changed, 2 insertions, 12 deletions
diff --git a/rag/ui.py b/rag/ui.py
index 23c44f9..fb02e5c 100644
--- a/rag/ui.py
+++ b/rag/ui.py
@@ -46,11 +46,6 @@ def set_chat_users():
ss.bot = Ollama.BOT.value
-def clear_generator_messages():
- log.debug("Clearing generator chat history")
- st.session_state.generator_messages = []
-
-
@st.cache_resource
def load_retriever():
log.debug("Loading retriever model")
@@ -62,7 +57,6 @@ def load_generator(client: str):
log.debug("Loading generator model")
st.session_state.generator = get_generator(client)
set_chat_users()
- clear_generator_messages()
@st.cache_data(show_spinner=False)
@@ -102,12 +96,11 @@ def generate_chat(query: str):
retriever = ss.retriever
generator = ss.generator
- documents = retriever.retrieve(query, limit=5)
+ documents = retriever.retrieve(query, limit=15)
prompt = Prompt(query, documents)
with st.chat_message(ss.bot):
- history = [m.as_dict(ss.client) for m in ss.generator_messages]
- response = st.write_stream(generator.chat(prompt, history))
+ response = st.write_stream(generator.generate(prompt))
display_context(documents)
store_chat(query, response, documents)
@@ -120,7 +113,6 @@ def store_chat(query: str, response: str, documents: List[Document]):
response = Message(role=ss.bot, message=response)
ss.chat.append(query)
ss.chat.append(response)
- ss.generator_messages.append(response)
ss.chat.append(documents)
@@ -150,8 +142,6 @@ def sidebar():
def page():
ss = st.session_state
- if "generator_messages" not in st.session_state:
- ss.generator_messages = []
if "chat" not in st.session_state:
ss.chat = []