From 95305f59df84caded50286b1a57b6075e48725a8 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Wed, 24 Apr 2024 01:10:43 +0200
Subject: Rerank working

llama3 sucks at rag
---
 rag/ui.py | 51 ++++++++++++++++++++++++---------------------------
 1 file changed, 24 insertions(+), 27 deletions(-)

diff --git a/rag/ui.py b/rag/ui.py
index 2fbf8de..f46c24d 100644
--- a/rag/ui.py
+++ b/rag/ui.py
@@ -1,5 +1,4 @@
 from dataclasses import dataclass
-from enum import Enum
 from typing import Dict, List
 
 import streamlit as st
@@ -9,27 +8,18 @@ from loguru import logger as log
 
 from rag.generator import MODELS, get_generator
 from rag.generator.prompt import Prompt
+from rag.retriever.rerank import get_reranker
 from rag.retriever.retriever import Retriever
 from rag.retriever.vector import Document
 
 
-class Cohere(Enum):
-    USER = "USER"
-    BOT = "CHATBOT"
-
-
-class Ollama(Enum):
-    USER = "user"
-    BOT = "assistant"
-
-
 @dataclass
 class Message:
     role: str
     message: str
 
-    def as_dict(self, client: str) -> Dict[str, str]:
-        if client == "cohere":
+    def as_dict(self, model: str) -> Dict[str, str]:
+        if model == "cohere":
             return {"role": self.role, "message": self.message}
         else:
             return {"role": self.role, "content": self.message}
@@ -38,12 +28,8 @@ class Message:
 def set_chat_users():
     log.debug("Setting user and bot value")
     ss = st.session_state
-    if ss.generator == "cohere":
-        ss.user = Cohere.USER.value
-        ss.bot = Cohere.BOT.value
-    else:
-        ss.user = Ollama.USER.value
-        ss.bot = Ollama.BOT.value
+    ss.user = "user"
+    ss.bot = "assistant"
 
 
 @st.cache_resource
@@ -52,13 +38,19 @@ def load_retriever():
     st.session_state.retriever = Retriever()
 
 
-@st.cache_resource
-def load_generator(client: str):
+# @st.cache_resource
+def load_generator(model: str):
     log.debug("Loading generator model")
-    st.session_state.generator = get_generator(client)
+    st.session_state.generator = get_generator(model)
     set_chat_users()
 
 
+# @st.cache_resource
+def load_reranker(model: str):
+    log.debug("Loading reranker model")
+    st.session_state.reranker = get_reranker(model)
+
+
 @st.cache_data(show_spinner=False)
 def upload(files):
     retriever = st.session_state.retriever
@@ -95,11 +87,12 @@ def generate_chat(query: str):
 
     retriever = ss.retriever
     generator = ss.generator
+    reranker = ss.reranker
 
-    documents = retriever.retrieve(query, limit=15)
+    documents = retriever.retrieve(query)
     prompt = Prompt(query, documents)
 
-    prompt = generator.rerank(prompt)
+    prompt = reranker.rank(prompt)
 
     with st.chat_message(ss.bot):
         response = st.write_stream(generator.generate(prompt))
@@ -137,9 +130,12 @@ def sidebar():
             upload(files)
 
         st.header("Generative Model")
-        st.markdown("Select the model that will be used for generating the answer.")
-        st.selectbox("Generative Model", key="client", options=MODELS)
-        load_generator(st.session_state.client)
+        st.markdown(
+            "Select the model that will be used for reranking and generating the answer."
+        )
+        st.selectbox("Model", key="model", options=MODELS)
+        load_generator(st.session_state.model)
+        load_reranker(st.session_state.model)
 
 
 def page():
@@ -157,6 +153,7 @@ def page():
 if __name__ == "__main__":
     load_dotenv()
     st.title("Retrieval Augmented Generation")
+    set_chat_users()
     load_retriever()
     sidebar()
     page()
-- 
cgit v1.2.3-70-g09d2
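
Note: the patch imports get_reranker from rag.retriever.rerank and calls reranker.rank(prompt) before generation, but that module is not part of this diff. The sketch below is only a hypothetical illustration of the interface the UI code expects; it assumes a local cross-encoder from sentence-transformers, that Prompt exposes query and documents, and that Document has a text field. None of these implementation details are confirmed by the commit.

# Hypothetical sketch of rag/retriever/rerank.py -- not part of this commit.
from sentence_transformers import CrossEncoder

from rag.generator.prompt import Prompt


class CrossEncoderReranker:
    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2") -> None:
        # A small cross-encoder scores (query, passage) pairs directly.
        self.model = CrossEncoder(model_name)

    def rank(self, prompt: Prompt, top_k: int = 5) -> Prompt:
        # Score every retrieved document against the query and keep the best ones.
        if not prompt.documents:
            return prompt
        pairs = [(prompt.query, doc.text) for doc in prompt.documents]  # assumes Document.text
        scores = self.model.predict(pairs)
        ranked = sorted(zip(scores, prompt.documents), key=lambda pair: pair[0], reverse=True)
        prompt.documents = [doc for _, doc in ranked[:top_k]]
        return prompt


def get_reranker(model: str) -> CrossEncoderReranker:
    # The UI passes the selected generator model name; this sketch ignores it and
    # always uses a local cross-encoder (an assumption, not taken from the commit).
    return CrossEncoderReranker()

In rag/ui.py the returned object is stored in st.session_state.reranker and applied as prompt = reranker.rank(prompt) before generator.generate(prompt), so any implementation only needs to accept and return a Prompt.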