summaryrefslogtreecommitdiff
path: root/rag/ui.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/ui.py')
-rw-r--r--rag/ui.py52
1 files changed, 27 insertions, 25 deletions
diff --git a/rag/ui.py b/rag/ui.py
index 9a0f2cf..2192ad8 100644
--- a/rag/ui.py
+++ b/rag/ui.py
@@ -1,15 +1,14 @@
-from dataclasses import dataclass
-from typing import Dict, List
+from typing import List
import streamlit as st
from dotenv import load_dotenv
from langchain_community.document_loaders.blob_loaders import Blob
from loguru import logger as log
-from rag.generator import MODELS, get_generator
+from rag.generator import MODELS
from rag.generator.prompt import Prompt
-from rag.retriever.rerank import get_reranker
-from rag.retriever.retriever import Retriever
+from rag.message import Message
+from rag.model import Rag
from rag.retriever.vector import Document
@@ -19,25 +18,27 @@ def set_chat_users():
ss.user = "user"
ss.bot = "assistant"
+@st.cache_resource
+def load_rag():
+ log.debug("Loading Rag...")
+ st.session_state.rag = Rag()
-def load_generator(model: str):
- log.debug("Loading rag")
- st.session_state.rag = get_generator(model)
-
-def load_reranker(model: str):
- log.debug("Loading reranker model")
- st.session_state.reranker = get_reranker(model)
+@st.cache_resource
+def set_client(client: str):
+ log.debug("Setting client...")
+ rag = st.session_state.rag
+ rag.set_client(client)
@st.cache_data(show_spinner=False)
def upload(files):
- retriever = st.session_state.retriever
+ rag = st.session_state.rag
with st.spinner("Uploading documents..."):
for file in files:
source = file.name
blob = Blob.from_data(file.read())
- retriever.add_pdf(blob=blob, source=source)
+ rag.retriever.add_pdf(blob=blob, source=source)
def display_context(documents: List[Document]):
@@ -55,7 +56,7 @@ def display_chat():
if isinstance(msg, list):
display_context(msg)
else:
- st.chat_message(msg.role).write(msg.message)
+ st.chat_message(msg.role).write(msg.content)
def generate_chat(query: str):
@@ -66,22 +67,24 @@ def generate_chat(query: str):
rag = ss.rag
documents = rag.retrieve(query)
- Prompt(query, documents, self.client)
+ prompt = Prompt(query, documents, ss.model)
with st.chat_message(ss.bot):
- response = st.write_stream(rag.generate(query))
+ response = st.write_stream(rag.generate(prompt))
+
+ rag.add_message(rag.bot, response)
display_context(prompt.documents)
- store_chat(query, response, prompt.documents)
+ store_chat(prompt, response)
-def store_chat(query: str, response: str, documents: List[Document]):
+def store_chat(prompt: Prompt, response: str):
log.debug("Storing chat")
ss = st.session_state
- query = Message(role=ss.user, message=query)
- response = Message(role=ss.bot, message=response)
+ query = Message(ss.user, prompt.query, ss.model)
+ response = Message(ss.bot, response, ss.model)
ss.chat.append(query)
ss.chat.append(response)
- ss.chat.append(documents)
+ ss.chat.append(prompt.documents)
def sidebar():
@@ -107,8 +110,7 @@ def sidebar():
"Select the model that will be used for reranking and generating the answer."
)
st.selectbox("Model", key="model", options=MODELS)
- load_generator(st.session_state.model)
- load_reranker(st.session_state.model)
+ set_client(st.session_state.model)
def page():
@@ -127,6 +129,6 @@ if __name__ == "__main__":
load_dotenv()
st.title("Retrieval Augmented Generation")
set_chat_users()
- load_retriever()
+ load_rag()
sidebar()
page()