summaryrefslogtreecommitdiff
path: root/rag/ui.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-13 12:32:23 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-13 12:32:23 +0200
commit4968ed48ed1adb267b910b28fdda0db115ba1b19 (patch)
tree67401c4c0a283b8d182b5a632b8db5233d0600be /rag/ui.py
parent72d1caf92115d90ae789de1cffed29406f2a0a39 (diff)
Fix ui bug
Diffstat (limited to 'rag/ui.py')
-rw-r--r--rag/ui.py78
1 files changed, 38 insertions, 40 deletions
diff --git a/rag/ui.py b/rag/ui.py
index abaf284..253ad87 100644
--- a/rag/ui.py
+++ b/rag/ui.py
@@ -2,10 +2,10 @@ from dataclasses import dataclass
from enum import Enum
from typing import Dict, List
-from loguru import logger as log
import streamlit as st
from dotenv import load_dotenv
from langchain_community.document_loaders.blob_loaders import Blob
+from loguru import logger as log
from rag.generator import MODELS, get_generator
from rag.generator.prompt import Prompt
@@ -46,9 +46,9 @@ def set_chat_users():
ss.bot = Ollama.BOT.value
-def clear_messages():
- log.debug("Clearing llm chat history")
- st.session_state.messages = []
+def clear_generator_messages():
+ log.debug("Clearing generator chat history")
+ st.session_state.generator_messages = []
@st.cache_resource
@@ -62,43 +62,19 @@ def load_generator(client: str):
log.debug("Loading generator model")
st.session_state.generator = get_generator(client)
set_chat_users()
- clear_messages()
+ clear_generator_messages()
@st.cache_data(show_spinner=False)
def upload(files):
- with st.spinner("Indexing documents..."):
- retriever = st.session_state.retriever
+ retriever = st.session_state.retriever
+ with st.spinner("Uploading documents..."):
for file in files:
source = file.name
blob = Blob.from_data(file.read())
retriever.add_pdf(blob=blob, source=source)
-def sidebar():
- with st.sidebar:
- st.header("Grouding")
- st.markdown(
- (
- "These files will be uploaded to the knowledge base and used "
- "as groudning if they are relevant to the question."
- )
- )
-
- files = st.file_uploader(
- "Choose pdfs to add to the knowledge base",
- type="pdf",
- accept_multiple_files=True,
- )
-
- upload(files)
-
- st.header("Generative Model")
- st.markdown("Select the model that will be used for generating the answer.")
- st.selectbox("Generative Model", key="client", options=MODELS)
- load_generator(st.session_state.client)
-
-
def display_context(documents: List[Document]):
with st.popover("See Context"):
for i, doc in enumerate(documents):
@@ -119,20 +95,20 @@ def display_chat():
def generate_chat(query: str):
ss = st.session_state
+
with st.chat_message(ss.user):
st.write(query)
retriever = ss.retriever
generator = ss.generator
- with st.spinner("Searching for documents..."):
- documents = retriever.retrieve(query)
-
+ documents = retriever.retrieve(query)
prompt = Prompt(query, documents)
with st.chat_message(ss.bot):
- history = [m.as_dict(ss.client) for m in ss.messages]
+ history = [m.as_dict(ss.client) for m in ss.generator_messages]
response = st.write_stream(generator.chat(prompt, history))
+
display_context(documents)
store_chat(query, response, documents)
@@ -144,22 +120,44 @@ def store_chat(query: str, response: str, documents: List[Document]):
response = Message(role=ss.bot, message=response)
ss.chat.append(query)
ss.chat.append(response)
- ss.messages.append(response)
+ ss.generator_messages.append(response)
ss.chat.append(documents)
+def sidebar():
+ with st.sidebar:
+ st.header("Grouding")
+ st.markdown(
+ (
+ "These files will be uploaded to the knowledge base and used "
+ "as groudning if they are relevant to the question."
+ )
+ )
+
+ files = st.file_uploader(
+ "Choose pdfs to add to the knowledge base",
+ type="pdf",
+ accept_multiple_files=True,
+ )
+
+ upload(files)
+
+ st.header("Generative Model")
+ st.markdown("Select the model that will be used for generating the answer.")
+ st.selectbox("Generative Model", key="client", options=MODELS)
+ load_generator(st.session_state.client)
+
+
def page():
ss = st.session_state
-
- if "messages" not in st.session_state:
- ss.messages = []
+ if "generator_messages" not in st.session_state:
+ ss.generator_messages = []
if "chat" not in st.session_state:
ss.chat = []
display_chat()
query = st.chat_input("Enter query here")
-
if query:
generate_chat(query)