From aac821b148c6c0d35b940609dc9b0ddcb053b28e Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Wed, 19 Jun 2024 02:07:06 +0200 Subject: Still wip on rewrite --- rag/ui.py | 52 +++++++++++++++++++++++++++------------------------- 1 file changed, 27 insertions(+), 25 deletions(-) (limited to 'rag/ui.py') diff --git a/rag/ui.py b/rag/ui.py index 9a0f2cf..2192ad8 100644 --- a/rag/ui.py +++ b/rag/ui.py @@ -1,15 +1,14 @@ -from dataclasses import dataclass -from typing import Dict, List +from typing import List import streamlit as st from dotenv import load_dotenv from langchain_community.document_loaders.blob_loaders import Blob from loguru import logger as log -from rag.generator import MODELS, get_generator +from rag.generator import MODELS from rag.generator.prompt import Prompt -from rag.retriever.rerank import get_reranker -from rag.retriever.retriever import Retriever +from rag.message import Message +from rag.model import Rag from rag.retriever.vector import Document @@ -19,25 +18,27 @@ def set_chat_users(): ss.user = "user" ss.bot = "assistant" +@st.cache_resource +def load_rag(): + log.debug("Loading Rag...") + st.session_state.rag = Rag() -def load_generator(model: str): - log.debug("Loading rag") - st.session_state.rag = get_generator(model) - -def load_reranker(model: str): - log.debug("Loading reranker model") - st.session_state.reranker = get_reranker(model) +@st.cache_resource +def set_client(client: str): + log.debug("Setting client...") + rag = st.session_state.rag + rag.set_client(client) @st.cache_data(show_spinner=False) def upload(files): - retriever = st.session_state.retriever + rag = st.session_state.rag with st.spinner("Uploading documents..."): for file in files: source = file.name blob = Blob.from_data(file.read()) - retriever.add_pdf(blob=blob, source=source) + rag.retriever.add_pdf(blob=blob, source=source) def display_context(documents: List[Document]): @@ -55,7 +56,7 @@ def display_chat(): if isinstance(msg, list): display_context(msg) else: - st.chat_message(msg.role).write(msg.message) + st.chat_message(msg.role).write(msg.content) def generate_chat(query: str): @@ -66,22 +67,24 @@ def generate_chat(query: str): rag = ss.rag documents = rag.retrieve(query) - Prompt(query, documents, self.client) + prompt = Prompt(query, documents, ss.model) with st.chat_message(ss.bot): - response = st.write_stream(rag.generate(query)) + response = st.write_stream(rag.generate(prompt)) + + rag.add_message(rag.bot, response) display_context(prompt.documents) - store_chat(query, response, prompt.documents) + store_chat(prompt, response) -def store_chat(query: str, response: str, documents: List[Document]): +def store_chat(prompt: Prompt, response: str): log.debug("Storing chat") ss = st.session_state - query = Message(role=ss.user, message=query) - response = Message(role=ss.bot, message=response) + query = Message(ss.user, prompt.query, ss.model) + response = Message(ss.bot, response, ss.model) ss.chat.append(query) ss.chat.append(response) - ss.chat.append(documents) + ss.chat.append(prompt.documents) def sidebar(): @@ -107,8 +110,7 @@ def sidebar(): "Select the model that will be used for reranking and generating the answer." ) st.selectbox("Model", key="model", options=MODELS) - load_generator(st.session_state.model) - load_reranker(st.session_state.model) + set_client(st.session_state.model) def page(): @@ -127,6 +129,6 @@ if __name__ == "__main__": load_dotenv() st.title("Retrieval Augmented Generation") set_chat_users() - load_retriever() + load_rag() sidebar() page() -- cgit v1.2.3-70-g09d2