summary | refs | log | tree | commit | diff
path: root/rag/ui.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/ui.py')
-rw-r--r-- rag/ui.py | 51
1 files changed, 24 insertions, 27 deletions
diff --git a/rag/ui.py b/rag/ui.py
index 2fbf8de..f46c24d 100644
--- a/rag/ui.py
+++ b/rag/ui.py
@@ -1,5 +1,4 @@
from dataclasses import dataclass
-from enum import Enum
from typing import Dict, List
import streamlit as st
@@ -9,27 +8,18 @@ from loguru import logger as log
from rag.generator import MODELS, get_generator
from rag.generator.prompt import Prompt
+from rag.retriever.rerank import get_reranker
from rag.retriever.retriever import Retriever
from rag.retriever.vector import Document
-class Cohere(Enum):
- USER = "USER"
- BOT = "CHATBOT"
-
-
-class Ollama(Enum):
- USER = "user"
- BOT = "assistant"
-
-
@dataclass
class Message:
role: str
message: str
- def as_dict(self, client: str) -> Dict[str, str]:
- if client == "cohere":
+ def as_dict(self, model: str) -> Dict[str, str]:
+ if model == "cohere":
return {"role": self.role, "message": self.message}
else:
return {"role": self.role, "content": self.message}
@@ -38,12 +28,8 @@ class Message:
def set_chat_users():
log.debug("Setting user and bot value")
ss = st.session_state
- if ss.generator == "cohere":
- ss.user = Cohere.USER.value
- ss.bot = Cohere.BOT.value
- else:
- ss.user = Ollama.USER.value
- ss.bot = Ollama.BOT.value
+ ss.user = "user"
+ ss.bot = "assistant"
@st.cache_resource
@@ -52,13 +38,19 @@ def load_retriever():
st.session_state.retriever = Retriever()
-@st.cache_resource
-def load_generator(client: str):
+# @st.cache_resource
+def load_generator(model: str):
log.debug("Loading generator model")
- st.session_state.generator = get_generator(client)
+ st.session_state.generator = get_generator(model)
set_chat_users()
+# @st.cache_resource
+def load_reranker(model: str):
+ log.debug("Loading reranker model")
+ st.session_state.reranker = get_reranker(model)
+
+
@st.cache_data(show_spinner=False)
def upload(files):
retriever = st.session_state.retriever
@@ -95,11 +87,12 @@ def generate_chat(query: str):
retriever = ss.retriever
generator = ss.generator
+ reranker = ss.reranker
- documents = retriever.retrieve(query, limit=15)
+ documents = retriever.retrieve(query)
prompt = Prompt(query, documents)
- prompt = generator.rerank(prompt)
+ prompt = reranker.rank(prompt)
with st.chat_message(ss.bot):
response = st.write_stream(generator.generate(prompt))
@@ -137,9 +130,12 @@ def sidebar():
upload(files)
st.header("Generative Model")
- st.markdown("Select the model that will be used for generating the answer.")
- st.selectbox("Generative Model", key="client", options=MODELS)
- load_generator(st.session_state.client)
+ st.markdown(
+ "Select the model that will be used for reranking and generating the answer."
+ )
+ st.selectbox("Model", key="model", options=MODELS)
+ load_generator(st.session_state.model)
+ load_reranker(st.session_state.model)
def page():
@@ -157,6 +153,7 @@ def page():
if __name__ == "__main__":
load_dotenv()
st.title("Retrieval Augmented Generation")
+ set_chat_users()
load_retriever()
sidebar()
page()