summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-13 02:26:01 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-13 02:26:01 +0200
commit72d1caf92115d90ae789de1cffed29406f2a0a39 (patch)
treeca4f3b755dccdd94894f18e7ce599cfd7eb28e58
parent36722903391ec42d5458112bc0549eb843548d90 (diff)
Wip chat ui
-rw-r--r--rag/generator/abstract.py8
-rw-r--r--rag/generator/cohere.py22
-rw-r--r--rag/generator/ollama.py9
-rw-r--r--rag/ui.py160
4 files changed, 156 insertions, 43 deletions
diff --git a/rag/generator/abstract.py b/rag/generator/abstract.py
index 1beacfb..71edfc4 100644
--- a/rag/generator/abstract.py
+++ b/rag/generator/abstract.py
@@ -1,5 +1,5 @@
from abc import abstractmethod
-from typing import Any, Generator
+from typing import Any, Dict, Generator, List
from .prompt import Prompt
@@ -16,3 +16,9 @@ class AbstractGenerator(type):
@abstractmethod
def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
pass
+
+ @abstractmethod
+ def chat(
+ self, prompt: Prompt, messages: List[Dict[str, str]]
+ ) -> Generator[Any, Any, Any]:
+ pass
diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py
index 2ed2cf5..16dfe88 100644
--- a/rag/generator/cohere.py
+++ b/rag/generator/cohere.py
@@ -1,6 +1,6 @@
import os
from dataclasses import asdict
-from typing import Any, Generator
+from typing import Any, Dict, Generator, List
import cohere
from loguru import logger as log
@@ -14,7 +14,7 @@ class Cohere(metaclass=AbstractGenerator):
self.client = cohere.Client(os.environ["COHERE_API_KEY"])
def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
- log.debug("Generating answer from cohere")
+ log.debug("Generating answer from cohere...")
query = f"{prompt.query}\n\n{ANSWER_INSTRUCTION}"
for event in self.client.chat_stream(
message=query,
@@ -27,3 +27,21 @@ class Cohere(metaclass=AbstractGenerator):
yield event.citations
elif event.event_type == "stream-end":
yield event.finish_reason
+
+ def chat(
+ self, prompt: Prompt, messages: List[Dict[str, str]]
+ ) -> Generator[Any, Any, Any]:
+ log.debug("Generating answer from cohere...")
+ query = f"{prompt.query}\n\n{ANSWER_INSTRUCTION}"
+ for event in self.client.chat_stream(
+ message=query,
+ documents=[asdict(d) for d in prompt.documents],
+ chat_history=messages,
+ prompt_truncation="AUTO",
+ ):
+ if event.event_type == "text-generation":
+ yield event.text
+ # elif event.event_type == "citation-generation":
+ # yield event.citations
+ elif event.event_type == "stream-end":
+ yield event.finish_reason
diff --git a/rag/generator/ollama.py b/rag/generator/ollama.py
index 6340235..b475dcf 100644
--- a/rag/generator/ollama.py
+++ b/rag/generator/ollama.py
@@ -1,5 +1,5 @@
import os
-from typing import Any, Generator, List
+from typing import Any, Dict, Generator, List
import ollama
from loguru import logger as log
@@ -38,3 +38,10 @@ class Ollama(metaclass=AbstractGenerator):
metaprompt = self.__metaprompt(prompt)
for chunk in ollama.generate(model=self.model, prompt=metaprompt, stream=True):
yield chunk["response"]
+
+ def chat(self, prompt: Prompt, messages: List[Dict[str, str]]) -> Generator[Any, Any, Any]:
+ log.debug("Generating answer with ollama...")
+ metaprompt = self.__metaprompt(prompt)
+ messages.append({"role": "user", "content": metaprompt})
+ for chunk in ollama.chat(model=self.model, messages=messages, stream=True):
+ yield chunk["message"]["content"]
diff --git a/rag/ui.py b/rag/ui.py
index 40da9dd..abaf284 100644
--- a/rag/ui.py
+++ b/rag/ui.py
@@ -1,38 +1,81 @@
-from typing import Type
+from dataclasses import dataclass
+from enum import Enum
+from typing import Dict, List
+from loguru import logger as log
import streamlit as st
from dotenv import load_dotenv
from langchain_community.document_loaders.blob_loaders import Blob
from rag.generator import MODELS, get_generator
-from rag.generator.abstract import AbstractGenerator
from rag.generator.prompt import Prompt
from rag.retriever.retriever import Retriever
+from rag.retriever.vector import Document
+
+
+class Cohere(Enum):
+ USER = "USER"
+ BOT = "CHATBOT"
+
+
+class Ollama(Enum):
+ USER = "user"
+ BOT = "assistant"
+
+
+@dataclass
+class Message:
+ role: str
+ message: str
+
+ def as_dict(self, client: str) -> Dict[str, str]:
+ if client == "cohere":
+ return {"role": self.role, "message": self.message}
+ else:
+ return {"role": self.role, "content": self.message}
+
+
+def set_chat_users():
+ log.debug("Setting user and bot value")
+ ss = st.session_state
+ if ss.generator == "cohere":
+ ss.user = Cohere.USER.value
+ ss.bot = Cohere.BOT.value
+ else:
+ ss.user = Ollama.USER.value
+ ss.bot = Ollama.BOT.value
+
+
+def clear_messages():
+ log.debug("Clearing llm chat history")
+ st.session_state.messages = []
@st.cache_resource
-def load_retriever() -> Retriever:
- return Retriever()
+def load_retriever():
+ log.debug("Loading retriever model")
+ st.session_state.retriever = Retriever()
@st.cache_resource
-def load_generator(model: str) -> Type[AbstractGenerator]:
- return get_generator(model)
+def load_generator(client: str):
+ log.debug("Loading generator model")
+ st.session_state.generator = get_generator(client)
+ set_chat_users()
+ clear_messages()
@st.cache_data(show_spinner=False)
def upload(files):
with st.spinner("Indexing documents..."):
+ retriever = st.session_state.retriever
for file in files:
source = file.name
blob = Blob.from_data(file.read())
retriever.add_pdf(blob=blob, source=source)
-if __name__ == "__main__":
- load_dotenv()
- retriever = load_retriever()
-
+def sidebar():
with st.sidebar:
st.header("Grouding")
st.markdown(
@@ -52,39 +95,78 @@ if __name__ == "__main__":
st.header("Generative Model")
st.markdown("Select the model that will be used for generating the answer.")
- model = st.selectbox("Generative Model", options=MODELS)
- generator = load_generator(model)
+ st.selectbox("Generative Model", key="client", options=MODELS)
+ load_generator(st.session_state.client)
- st.title("Retrieval Augmented Generation")
- with st.form(key="query"):
- query = st.text_area(
- "query",
- key="query",
- height=100,
- placeholder="Enter query here",
- help="",
- label_visibility="collapsed",
- disabled=False,
- )
- submit = st.form_submit_button("Generate")
+def display_context(documents: List[Document]):
+ with st.popover("See Context"):
+ for i, doc in enumerate(documents):
+ st.markdown(f"### Document {i}")
+ st.markdown(f"**Title: {doc.title}**")
+ st.markdown(doc.text)
+ st.markdown("---")
- (result_column, context_column) = st.columns(2)
- if submit and query:
- with st.spinner("Searching for documents..."):
- documents = retriever.retrieve(query)
+def display_chat():
+ ss = st.session_state
+ for msg in ss.chat:
+ if isinstance(msg, list):
+ display_context(msg)
+ else:
+ st.chat_message(msg.role).write(msg.message)
- prompt = Prompt(query, documents)
- with context_column:
- st.markdown("### Context")
- for i, doc in enumerate(documents):
- st.markdown(f"### Document {i}")
- st.markdown(f"**Title: {doc.title}**")
- st.markdown(doc.text)
- st.markdown("---")
+def generate_chat(query: str):
+ ss = st.session_state
+ with st.chat_message(ss.user):
+ st.write(query)
- with result_column:
- st.markdown("### Answer")
- st.write_stream(generator.generate(prompt))
+ retriever = ss.retriever
+ generator = ss.generator
+
+ with st.spinner("Searching for documents..."):
+ documents = retriever.retrieve(query)
+
+ prompt = Prompt(query, documents)
+
+ with st.chat_message(ss.bot):
+ history = [m.as_dict(ss.client) for m in ss.messages]
+ response = st.write_stream(generator.chat(prompt, history))
+ display_context(documents)
+ store_chat(query, response, documents)
+
+
+def store_chat(query: str, response: str, documents: List[Document]):
+ log.debug("Storing chat")
+ ss = st.session_state
+ query = Message(role=ss.user, message=query)
+ response = Message(role=ss.bot, message=response)
+ ss.chat.append(query)
+ ss.chat.append(response)
+ ss.messages.append(response)
+ ss.chat.append(documents)
+
+
+def page():
+ ss = st.session_state
+
+ if "messages" not in st.session_state:
+ ss.messages = []
+ if "chat" not in st.session_state:
+ ss.chat = []
+
+ display_chat()
+
+ query = st.chat_input("Enter query here")
+
+ if query:
+ generate_chat(query)
+
+
+if __name__ == "__main__":
+ load_dotenv()
+ st.title("Retrieval Augmented Generation")
+ load_retriever()
+ sidebar()
+ page()