From 72d1caf92115d90ae789de1cffed29406f2a0a39 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sat, 13 Apr 2024 02:26:01 +0200 Subject: Wip chat ui --- rag/generator/abstract.py | 8 ++- rag/generator/cohere.py | 22 ++++++- rag/generator/ollama.py | 9 ++- rag/ui.py | 160 +++++++++++++++++++++++++++++++++++----------- 4 files changed, 156 insertions(+), 43 deletions(-) diff --git a/rag/generator/abstract.py b/rag/generator/abstract.py index 1beacfb..71edfc4 100644 --- a/rag/generator/abstract.py +++ b/rag/generator/abstract.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Any, Generator +from typing import Any, Dict, Generator, List from .prompt import Prompt @@ -16,3 +16,9 @@ class AbstractGenerator(type): @abstractmethod def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]: pass + + @abstractmethod + def chat( + self, prompt: Prompt, messages: List[Dict[str, str]] + ) -> Generator[Any, Any, Any]: + pass diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py index 2ed2cf5..16dfe88 100644 --- a/rag/generator/cohere.py +++ b/rag/generator/cohere.py @@ -1,6 +1,6 @@ import os from dataclasses import asdict -from typing import Any, Generator +from typing import Any, Dict, Generator, List import cohere from loguru import logger as log @@ -14,7 +14,7 @@ class Cohere(metaclass=AbstractGenerator): self.client = cohere.Client(os.environ["COHERE_API_KEY"]) def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]: - log.debug("Generating answer from cohere") + log.debug("Generating answer from cohere...") query = f"{prompt.query}\n\n{ANSWER_INSTRUCTION}" for event in self.client.chat_stream( message=query, @@ -27,3 +27,21 @@ class Cohere(metaclass=AbstractGenerator): yield event.citations elif event.event_type == "stream-end": yield event.finish_reason + + def chat( + self, prompt: Prompt, messages: List[Dict[str, str]] + ) -> Generator[Any, Any, Any]: + log.debug("Generating answer from cohere...") + query = f"{prompt.query}\n\n{ANSWER_INSTRUCTION}" + for event in self.client.chat_stream( + message=query, + documents=[asdict(d) for d in prompt.documents], + chat_history=messages, + prompt_truncation="AUTO", + ): + if event.event_type == "text-generation": + yield event.text + # elif event.event_type == "citation-generation": + # yield event.citations + elif event.event_type == "stream-end": + yield event.finish_reason diff --git a/rag/generator/ollama.py b/rag/generator/ollama.py index 6340235..b475dcf 100644 --- a/rag/generator/ollama.py +++ b/rag/generator/ollama.py @@ -1,5 +1,5 @@ import os -from typing import Any, Generator, List +from typing import Any, Dict, Generator, List import ollama from loguru import logger as log @@ -38,3 +38,10 @@ class Ollama(metaclass=AbstractGenerator): metaprompt = self.__metaprompt(prompt) for chunk in ollama.generate(model=self.model, prompt=metaprompt, stream=True): yield chunk["response"] + + def chat(self, prompt: Prompt, messages: List[Dict[str, str]]) -> Generator[Any, Any, Any]: + log.debug("Generating answer with ollama...") + metaprompt = self.__metaprompt(prompt) + messages.append({"role": "user", "content": metaprompt}) + for chunk in ollama.chat(model=self.model, messages=messages, stream=True): + yield chunk["message"]["content"] diff --git a/rag/ui.py b/rag/ui.py index 40da9dd..abaf284 100644 --- a/rag/ui.py +++ b/rag/ui.py @@ -1,38 +1,81 @@ -from typing import Type +from dataclasses import dataclass +from enum import Enum +from typing import Dict, List +from loguru import logger as log import streamlit as st from dotenv import load_dotenv from langchain_community.document_loaders.blob_loaders import Blob from rag.generator import MODELS, get_generator -from rag.generator.abstract import AbstractGenerator from rag.generator.prompt import Prompt from rag.retriever.retriever import Retriever +from rag.retriever.vector import Document + + +class Cohere(Enum): + USER = "USER" + BOT = "CHATBOT" + + +class Ollama(Enum): + USER = "user" + BOT = "assistant" + + +@dataclass +class Message: + role: str + message: str + + def as_dict(self, client: str) -> Dict[str, str]: + if client == "cohere": + return {"role": self.role, "message": self.message} + else: + return {"role": self.role, "content": self.message} + + +def set_chat_users(): + log.debug("Setting user and bot value") + ss = st.session_state + if ss.generator == "cohere": + ss.user = Cohere.USER.value + ss.bot = Cohere.BOT.value + else: + ss.user = Ollama.USER.value + ss.bot = Ollama.BOT.value + + +def clear_messages(): + log.debug("Clearing llm chat history") + st.session_state.messages = [] @st.cache_resource -def load_retriever() -> Retriever: - return Retriever() +def load_retriever(): + log.debug("Loading retriever model") + st.session_state.retriever = Retriever() @st.cache_resource -def load_generator(model: str) -> Type[AbstractGenerator]: - return get_generator(model) +def load_generator(client: str): + log.debug("Loading generator model") + st.session_state.generator = get_generator(client) + set_chat_users() + clear_messages() @st.cache_data(show_spinner=False) def upload(files): with st.spinner("Indexing documents..."): + retriever = st.session_state.retriever for file in files: source = file.name blob = Blob.from_data(file.read()) retriever.add_pdf(blob=blob, source=source) -if __name__ == "__main__": - load_dotenv() - retriever = load_retriever() - +def sidebar(): with st.sidebar: st.header("Grouding") st.markdown( @@ -52,39 +95,78 @@ if __name__ == "__main__": st.header("Generative Model") st.markdown("Select the model that will be used for generating the answer.") - model = st.selectbox("Generative Model", options=MODELS) - generator = load_generator(model) + st.selectbox("Generative Model", key="client", options=MODELS) + load_generator(st.session_state.client) - st.title("Retrieval Augmented Generation") - with st.form(key="query"): - query = st.text_area( - "query", - key="query", - height=100, - placeholder="Enter query here", - help="", - label_visibility="collapsed", - disabled=False, - ) - submit = st.form_submit_button("Generate") +def display_context(documents: List[Document]): + with st.popover("See Context"): + for i, doc in enumerate(documents): + st.markdown(f"### Document {i}") + st.markdown(f"**Title: {doc.title}**") + st.markdown(doc.text) + st.markdown("---") - (result_column, context_column) = st.columns(2) - if submit and query: - with st.spinner("Searching for documents..."): - documents = retriever.retrieve(query) +def display_chat(): + ss = st.session_state + for msg in ss.chat: + if isinstance(msg, list): + display_context(msg) + else: + st.chat_message(msg.role).write(msg.message) - prompt = Prompt(query, documents) - with context_column: - st.markdown("### Context") - for i, doc in enumerate(documents): - st.markdown(f"### Document {i}") - st.markdown(f"**Title: {doc.title}**") - st.markdown(doc.text) - st.markdown("---") +def generate_chat(query: str): + ss = st.session_state + with st.chat_message(ss.user): + st.write(query) - with result_column: - st.markdown("### Answer") - st.write_stream(generator.generate(prompt)) + retriever = ss.retriever + generator = ss.generator + + with st.spinner("Searching for documents..."): + documents = retriever.retrieve(query) + + prompt = Prompt(query, documents) + + with st.chat_message(ss.bot): + history = [m.as_dict(ss.client) for m in ss.messages] + response = st.write_stream(generator.chat(prompt, history)) + display_context(documents) + store_chat(query, response, documents) + + +def store_chat(query: str, response: str, documents: List[Document]): + log.debug("Storing chat") + ss = st.session_state + query = Message(role=ss.user, message=query) + response = Message(role=ss.bot, message=response) + ss.chat.append(query) + ss.chat.append(response) + ss.messages.append(response) + ss.chat.append(documents) + + +def page(): + ss = st.session_state + + if "messages" not in st.session_state: + ss.messages = [] + if "chat" not in st.session_state: + ss.chat = [] + + display_chat() + + query = st.chat_input("Enter query here") + + if query: + generate_chat(query) + + +if __name__ == "__main__": + load_dotenv() + st.title("Retrieval Augmented Generation") + load_retriever() + sidebar() + page() -- cgit v1.2.3-70-g09d2