From 4968ed48ed1adb267b910b28fdda0db115ba1b19 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sat, 13 Apr 2024 12:32:23 +0200 Subject: Fix ui bug --- rag/ui.py | 78 +++++++++++++++++++++++++++++++-------------------------------- 1 file changed, 38 insertions(+), 40 deletions(-) diff --git a/rag/ui.py b/rag/ui.py index abaf284..253ad87 100644 --- a/rag/ui.py +++ b/rag/ui.py @@ -2,10 +2,10 @@ from dataclasses import dataclass from enum import Enum from typing import Dict, List -from loguru import logger as log import streamlit as st from dotenv import load_dotenv from langchain_community.document_loaders.blob_loaders import Blob +from loguru import logger as log from rag.generator import MODELS, get_generator from rag.generator.prompt import Prompt @@ -46,9 +46,9 @@ def set_chat_users(): ss.bot = Ollama.BOT.value -def clear_messages(): - log.debug("Clearing llm chat history") - st.session_state.messages = [] +def clear_generator_messages(): + log.debug("Clearing generator chat history") + st.session_state.generator_messages = [] @st.cache_resource @@ -62,43 +62,19 @@ def load_generator(client: str): log.debug("Loading generator model") st.session_state.generator = get_generator(client) set_chat_users() - clear_messages() + clear_generator_messages() @st.cache_data(show_spinner=False) def upload(files): - with st.spinner("Indexing documents..."): - retriever = st.session_state.retriever + retriever = st.session_state.retriever + with st.spinner("Uploading documents..."): for file in files: source = file.name blob = Blob.from_data(file.read()) retriever.add_pdf(blob=blob, source=source) -def sidebar(): - with st.sidebar: - st.header("Grouding") - st.markdown( - ( - "These files will be uploaded to the knowledge base and used " - "as groudning if they are relevant to the question." - ) - ) - - files = st.file_uploader( - "Choose pdfs to add to the knowledge base", - type="pdf", - accept_multiple_files=True, - ) - - upload(files) - - st.header("Generative Model") - st.markdown("Select the model that will be used for generating the answer.") - st.selectbox("Generative Model", key="client", options=MODELS) - load_generator(st.session_state.client) - - def display_context(documents: List[Document]): with st.popover("See Context"): for i, doc in enumerate(documents): @@ -119,20 +95,20 @@ def display_chat(): def generate_chat(query: str): ss = st.session_state + with st.chat_message(ss.user): st.write(query) retriever = ss.retriever generator = ss.generator - with st.spinner("Searching for documents..."): - documents = retriever.retrieve(query) - + documents = retriever.retrieve(query) prompt = Prompt(query, documents) with st.chat_message(ss.bot): - history = [m.as_dict(ss.client) for m in ss.messages] + history = [m.as_dict(ss.client) for m in ss.generator_messages] response = st.write_stream(generator.chat(prompt, history)) + display_context(documents) store_chat(query, response, documents) @@ -144,22 +120,44 @@ def store_chat(query: str, response: str, documents: List[Document]): response = Message(role=ss.bot, message=response) ss.chat.append(query) ss.chat.append(response) - ss.messages.append(response) + ss.generator_messages.append(response) ss.chat.append(documents) +def sidebar(): + with st.sidebar: + st.header("Grouding") + st.markdown( + ( + "These files will be uploaded to the knowledge base and used " + "as groudning if they are relevant to the question." + ) + ) + + files = st.file_uploader( + "Choose pdfs to add to the knowledge base", + type="pdf", + accept_multiple_files=True, + ) + + upload(files) + + st.header("Generative Model") + st.markdown("Select the model that will be used for generating the answer.") + st.selectbox("Generative Model", key="client", options=MODELS) + load_generator(st.session_state.client) + + def page(): ss = st.session_state - - if "messages" not in st.session_state: - ss.messages = [] + if "generator_messages" not in st.session_state: + ss.generator_messages = [] if "chat" not in st.session_state: ss.chat = [] display_chat() query = st.chat_input("Enter query here") - if query: generate_chat(query) -- cgit v1.2.3-70-g09d2