summaryrefslogtreecommitdiff
path: root/rag/ui.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/ui.py')
-rw-r--r--rag/ui.py160
1 files changed, 121 insertions, 39 deletions
diff --git a/rag/ui.py b/rag/ui.py
index 40da9dd..abaf284 100644
--- a/rag/ui.py
+++ b/rag/ui.py
@@ -1,38 +1,81 @@
-from typing import Type
+from dataclasses import dataclass
+from enum import Enum
+from typing import Dict, List
+from loguru import logger as log
import streamlit as st
from dotenv import load_dotenv
from langchain_community.document_loaders.blob_loaders import Blob
from rag.generator import MODELS, get_generator
-from rag.generator.abstract import AbstractGenerator
from rag.generator.prompt import Prompt
from rag.retriever.retriever import Retriever
+from rag.retriever.vector import Document
+
+
class Cohere(Enum):
    """Chat-role labels expected by the Cohere chat API."""

    USER = "USER"
    BOT = "CHATBOT"
+
class Ollama(Enum):
    """Chat-role labels expected by the Ollama / OpenAI-style message format."""

    USER = "user"
    BOT = "assistant"
+
+
@dataclass
class Message:
    """One turn of the conversation: who spoke (`role`) and what was said."""

    role: str
    message: str

    def as_dict(self, client: str) -> Dict[str, str]:
        """Serialise this turn into the history schema of the given client.

        Cohere names the text field ``message``; every other client
        (Ollama-style) names it ``content``.
        """
        text_key = "message" if client == "cohere" else "content"
        return {"role": self.role, text_key: self.message}
+
+
def set_chat_users():
    """Store the user/bot role labels for the selected client in session state.

    Reads ``st.session_state.client`` (set by the sidebar selectbox) and
    writes ``st.session_state.user`` / ``st.session_state.bot``.
    """
    log.debug("Setting user and bot value")
    ss = st.session_state
    # Bug fix: this previously compared ss.generator == "cohere", but
    # load_generator stores the object returned by get_generator() in
    # ss.generator — never the string "cohere" — so the Cohere branch was
    # unreachable and Ollama role labels were used for both clients.
    # The selected client name lives in ss.client.
    roles = Cohere if ss.client == "cohere" else Ollama
    ss.user = roles.USER.value
    ss.bot = roles.BOT.value
+
+
def clear_messages():
    """Drop the accumulated LLM chat history for this session."""
    log.debug("Clearing llm chat history")
    session = st.session_state
    session.messages = []
@st.cache_resource
def load_retriever():
    """Build the Retriever once and expose it via ``st.session_state``.

    NOTE(review): ``st.cache_resource`` caches globally across sessions while
    ``session_state`` is per-session — a fresh session would hit the cache and
    skip this body; confirm this is the intended lifecycle.
    """
    log.debug("Loading retriever model")
    session = st.session_state
    session.retriever = Retriever()
@st.cache_resource
def load_generator(client: str):
    """Instantiate the generator for *client* and reset the chat state.

    Side effects: sets ``session_state.generator``, the user/bot role labels,
    and clears the message history.

    NOTE(review): as with load_retriever, ``st.cache_resource`` is global
    across sessions but the side effects land in per-session state — confirm
    this behaves as intended for a second browser session.
    """
    log.debug("Loading generator model")
    session = st.session_state
    session.generator = get_generator(client)
    set_chat_users()
    clear_messages()
@st.cache_data(show_spinner=False)
def upload(files):
    """Index every uploaded PDF into the retriever's store."""
    with st.spinner("Indexing documents..."):
        retriever = st.session_state.retriever
        for uploaded in files:
            name = uploaded.name
            pdf_blob = Blob.from_data(uploaded.read())
            retriever.add_pdf(blob=pdf_blob, source=name)
-if __name__ == "__main__":
- load_dotenv()
- retriever = load_retriever()
-
+def sidebar():
with st.sidebar:
 st.header("Grounding")
st.markdown(
@@ -52,39 +95,78 @@ if __name__ == "__main__":
st.header("Generative Model")
st.markdown("Select the model that will be used for generating the answer.")
- model = st.selectbox("Generative Model", options=MODELS)
- generator = load_generator(model)
+ st.selectbox("Generative Model", key="client", options=MODELS)
+ load_generator(st.session_state.client)
- st.title("Retrieval Augmented Generation")
- with st.form(key="query"):
- query = st.text_area(
- "query",
- key="query",
- height=100,
- placeholder="Enter query here",
- help="",
- label_visibility="collapsed",
- disabled=False,
- )
- submit = st.form_submit_button("Generate")
def display_context(documents: List[Document]):
    """Render the retrieved documents inside a collapsible popover."""
    with st.popover("See Context"):
        for index, document in enumerate(documents):
            st.markdown(f"### Document {index}")
            st.markdown(f"**Title: {document.title}**")
            st.markdown(document.text)
            st.markdown("---")
- (result_column, context_column) = st.columns(2)
- if submit and query:
- with st.spinner("Searching for documents..."):
- documents = retriever.retrieve(query)
def display_chat():
    """Replay the stored conversation, interleaving retrieved-context entries."""
    for entry in st.session_state.chat:
        # Retrieved documents are stored as a bare list between Message turns.
        if isinstance(entry, list):
            display_context(entry)
        else:
            st.chat_message(entry.role).write(entry.message)
- prompt = Prompt(query, documents)
- with context_column:
- st.markdown("### Context")
- for i, doc in enumerate(documents):
- st.markdown(f"### Document {i}")
- st.markdown(f"**Title: {doc.title}**")
- st.markdown(doc.text)
- st.markdown("---")
def generate_chat(query: str):
    """Handle one user query: retrieve context, stream the answer, persist it."""
    state = st.session_state

    # Echo the user's query into the chat transcript.
    with st.chat_message(state.user):
        st.write(query)

    with st.spinner("Searching for documents..."):
        documents = state.retriever.retrieve(query)

    prompt = Prompt(query, documents)

    with st.chat_message(state.bot):
        # Prior turns, serialised for whichever client is selected.
        history = [turn.as_dict(state.client) for turn in state.messages]
        response = st.write_stream(state.generator.chat(prompt, history))
        display_context(documents)
        store_chat(query, response, documents)
+
+
def store_chat(query: str, response: str, documents: List[Document]):
    """Persist the completed exchange in session state.

    Appends the user and bot turns (plus the retrieved documents) to the
    display transcript ``ss.chat``, and both turns to ``ss.messages`` — the
    history later serialised and sent to the generator.
    """
    log.debug("Storing chat")
    ss = st.session_state
    user_turn = Message(role=ss.user, message=query)
    bot_turn = Message(role=ss.bot, message=response)
    ss.chat.append(user_turn)
    ss.chat.append(bot_turn)
    # Bug fix: previously only the bot response was appended to ss.messages,
    # so the generator's chat history contained answers with no questions.
    ss.messages.append(user_turn)
    ss.messages.append(bot_turn)
    ss.chat.append(documents)
+
+
def page():
    """Main chat page: initialise state, replay history, accept a new query."""
    state = st.session_state

    # Per-session containers, created on the first run only.
    if "messages" not in state:
        state.messages = []
    if "chat" not in state:
        state.chat = []

    display_chat()

    if query := st.chat_input("Enter query here"):
        generate_chat(query)
+
+
if __name__ == "__main__":
    # Entry point: load environment variables (API keys etc.) from .env,
    # then build the page top-down.
    load_dotenv()
    st.title("Retrieval Augmented Generation")
    load_retriever()  # cached; stores the Retriever in session state
    sidebar()
    page()