diff options
Diffstat (limited to 'rag/ui.py')
-rw-r--r-- | rag/ui.py | 51 |
1 files changed, 24 insertions, 27 deletions
@@ -1,5 +1,4 @@ from dataclasses import dataclass -from enum import Enum from typing import Dict, List import streamlit as st @@ -9,27 +8,18 @@ from loguru import logger as log from rag.generator import MODELS, get_generator from rag.generator.prompt import Prompt +from rag.retriever.rerank import get_reranker from rag.retriever.retriever import Retriever from rag.retriever.vector import Document -class Cohere(Enum): - USER = "USER" - BOT = "CHATBOT" - - -class Ollama(Enum): - USER = "user" - BOT = "assistant" - - @dataclass class Message: role: str message: str - def as_dict(self, client: str) -> Dict[str, str]: - if client == "cohere": + def as_dict(self, model: str) -> Dict[str, str]: + if model == "cohere": return {"role": self.role, "message": self.message} else: return {"role": self.role, "content": self.message} @@ -38,12 +28,8 @@ class Message: def set_chat_users(): log.debug("Setting user and bot value") ss = st.session_state - if ss.generator == "cohere": - ss.user = Cohere.USER.value - ss.bot = Cohere.BOT.value - else: - ss.user = Ollama.USER.value - ss.bot = Ollama.BOT.value + ss.user = "user" + ss.bot = "assistant" @st.cache_resource @@ -52,13 +38,19 @@ def load_retriever(): st.session_state.retriever = Retriever() -@st.cache_resource -def load_generator(client: str): +# @st.cache_resource +def load_generator(model: str): log.debug("Loading generator model") - st.session_state.generator = get_generator(client) + st.session_state.generator = get_generator(model) set_chat_users() +# @st.cache_resource +def load_reranker(model: str): + log.debug("Loading reranker model") + st.session_state.reranker = get_reranker(model) + + @st.cache_data(show_spinner=False) def upload(files): retriever = st.session_state.retriever @@ -95,11 +87,12 @@ def generate_chat(query: str): retriever = ss.retriever generator = ss.generator + reranker = ss.reranker - documents = retriever.retrieve(query, limit=15) + documents = retriever.retrieve(query) prompt = Prompt(query, documents) - prompt = generator.rerank(prompt) + prompt = reranker.rank(prompt) with st.chat_message(ss.bot): response = st.write_stream(generator.generate(prompt)) @@ -137,9 +130,12 @@ def sidebar(): upload(files) st.header("Generative Model") - st.markdown("Select the model that will be used for generating the answer.") - st.selectbox("Generative Model", key="client", options=MODELS) - load_generator(st.session_state.client) + st.markdown( + "Select the model that will be used for reranking and generating the answer." + ) + st.selectbox("Model", key="model", options=MODELS) + load_generator(st.session_state.model) + load_reranker(st.session_state.model) def page(): @@ -157,6 +153,7 @@ def page(): if __name__ == "__main__": load_dotenv() st.title("Retrieval Augmented Generation") + set_chat_users() load_retriever() sidebar() page() |