From b1ff0c55422d7b0af2c379679b8721014ef36926 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Tue, 18 Jun 2024 01:37:32 +0200
Subject: Wip rewrite
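
Consolidate chat state and generation behind a new Rag orchestrator:

- Replace rag/retriever/memory.py with a Message dataclass and a Rag
  class in rag/rag.py that owns the retriever, reranker, generator and
  chat history.
- Move prompt formatting into Prompt.to_str(), which renders differently
  for cohere (documents go through the chat API) and ollama (context is
  inlined).
- Split the reranker interface into rerank_documents/rerank_messages so
  chat history can be filtered for relevance as well.
- Generators now receive the reranked message history explicitly.

Rough intended usage (still WIP, names may change):

    rag = Rag(client="ollama")
    documents = rag.retrieve(query)
    prompt = Prompt(query, documents, rag.client)
    for chunk in rag.generate(prompt):
        ...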
f"{self.__context(prompt.documents)}\n\n" - "---------------------\n" - f"{ANSWER_INSTRUCTION}" - "Do not attempt to answer the query without relevant context and do not use" - " prior knowledge or training data!\n" - f"Query: {prompt.query.strip()}\n\n" - "Answer:" - ) - return metaprompt - - def generate(self, prompt: Prompt, memory: Memory) -> Generator[Any, Any, Any]: + def generate( + self, prompt: Prompt, messages: List[Message] + ) -> Generator[Any, Any, Any]: log.debug("Generating answer with ollama...") - metaprompt = self.__metaprompt(prompt) - for chunk in ollama.chat(model=self.model, messages=memory.append(metaprompt), stream=True): + messages = messages.append( + Message(role="user", content=prompt.to_str(), client="ollama") + ) + messages = [m.as_dict() for m in messages] + for chunk in ollama.chat(model=self.model, messages=messages, stream=True): yield chunk["response"] diff --git a/rag/generator/prompt.py b/rag/generator/prompt.py index 6523842..4840fdc 100644 --- a/rag/generator/prompt.py +++ b/rag/generator/prompt.py @@ -15,3 +15,27 @@ ANSWER_INSTRUCTION = ( class Prompt: query: str documents: List[Document] + generator_model: str + + def __context(self, documents: List[Document]) -> str: + results = [ + f"Document: {i}\ntitle: {doc.title}\ntext: {doc.text}" + for i, doc in enumerate(documents) + ] + return "\n".join(results) + + def to_str(self) -> str: + if self.generator_model == "cohere": + return f"{self.query}\n\n{ANSWER_INSTRUCTION}" + else: + return ( + "Context information is below.\n" + "---------------------\n" + f"{self.__context(self.documents)}\n\n" + "---------------------\n" + f"{ANSWER_INSTRUCTION}" + "Do not attempt to answer the query without relevant context and do not use" + " prior knowledge or training data!\n" + f"Query: {self.query.strip()}\n\n" + "Answer:" + ) diff --git a/rag/rag.py b/rag/rag.py new file mode 100644 index 0000000..1f6a176 --- /dev/null +++ b/rag/rag.py @@ -0,0 +1,69 @@ +from dataclasses import dataclass +from typing import Any, Dict, Generator, List + +from loguru import logger as log + +from rag.generator import get_generator +from rag.generator.prompt import Prompt +from rag.retriever.rerank import get_reranker +from rag.retriever.retriever import Retriever +from rag.retriever.vector import Document + + +@dataclass +class Message: + role: str + content: str + client: str + + def as_dict(self) -> Dict[str, str]: + if self.client == "cohere": + return {"role": self.role, "message": self.content} + else: + return {"role": self.role, "content": self.content} + + +class Rag: + def __init__(self, client: str) -> None: + self.messages: List[Message] = [] + self.retriever = Retriever() + self.client = client + self.reranker = get_reranker(self.client) + self.generator = get_generator(self.client) + self.bot = "assistant" if self.client == "ollama" else "CHATBOT" + self.user = "user" if self.client == "ollama" else "USER" + + def __set_roles(self): + self.bot = "assistant" if self.client == "ollama" else "CHATBOT" + self.user = "user" if self.client == "ollama" else "USER" + + def set_client(self, client: str): + self.client = client + self.reranker = get_reranker(self.client) + self.generator = get_generator(self.client) + self.__set_roles() + self.__reset_messages() + log.debug(f"Swapped client to {self.client}") + + def __reset_messages(self): + log.debug("Deleting messages...") + self.messages = [] + + def retrieve(self, query: str) -> List[Document]: + documents = self.retriever.retrieve(query) + log.info(f"Found 
+        if self.client == "cohere":
+            return {"role": self.role, "message": self.content}
+        else:
+            return {"role": self.role, "content": self.content}
+
+
+class Rag:
+    def __init__(self, client: str) -> None:
+        self.messages: List[Message] = []
+        self.retriever = Retriever()
+        self.client = client
+        self.__load_clients()
+        self.__set_roles()
+
+    def __load_clients(self):
+        # Imported lazily: the generator and reranker modules import Message
+        # from this module, so a module-level import would be circular.
+        from rag.generator import get_generator
+        from rag.retriever.rerank import get_reranker
+
+        self.reranker = get_reranker(self.client)
+        self.generator = get_generator(self.client)
+
+    def __set_roles(self):
+        self.bot = "assistant" if self.client == "ollama" else "CHATBOT"
+        self.user = "user" if self.client == "ollama" else "USER"
+
+    def set_client(self, client: str):
+        self.client = client
+        self.__load_clients()
+        self.__set_roles()
+        self.__reset_messages()
+        log.debug(f"Swapped client to {self.client}")
+
+    def __reset_messages(self):
+        log.debug("Deleting messages...")
+        self.messages = []
+
+    def retrieve(self, query: str) -> List[Document]:
+        documents = self.retriever.retrieve(query)
+        log.info(f"Retrieved {len(documents)} documents")
+        return self.reranker.rerank_documents(query, documents)
+
+    def add_message(self, role: str, content: str):
+        self.messages.append(Message(role=role, content=content, client=self.client))
+
+    def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
+        # Pass only the history relevant to the current query on to the
+        # generator; the full history stays in self.messages.
+        messages = self.reranker.rerank_messages(prompt.query, self.messages)
+        self.messages.append(
+            Message(role=self.user, content=prompt.to_str(), client=self.client)
+        )
+        return self.generator.generate(prompt, messages)
diff --git a/rag/retriever/memory.py b/rag/retriever/memory.py
deleted file mode 100644
index c4455ed..0000000
--- a/rag/retriever/memory.py
+++ /dev/null
@@ -1,51 +0,0 @@
-from dataclasses import dataclass
-from typing import Dict, List
-
-
-@dataclass
-class Log:
-    user: Message
-    bot: Message
-
-    def get():
-        return (user, bot)
-
-
-@dataclass
-class Message:
-    role: str
-    message: str
-
-    def as_dict(self, model: str) -> Dict[str, str]:
-        if model == "cohere":
-            match self.role:
-                case "user":
-                    role = "USER"
-                case _:
-                    role = "CHATBOT"
-
-            return {"role": role, "message": self.message}
-        else:
-            return {"role": self.role, "content": self.message}
-
-
-class Memory:
-    def __init__(self, reranker) -> None:
-        self.history = []
-        self.reranker = reranker
-        self.user = "user"
-        self.bot = "assistant"
-
-    def add(self, prompt: str, response: str):
-        self.history.append(
-            Log(
-                user=Message(role=self.user, message=prompt),
-                bot=Message(role=self.bot, message=response),
-            )
-        )
-
-    def get(self) -> List[Log]:
-        return [m.as_dict() for log in self.history for m in log.get()]
-
-    def reset(self):
-        self.history = []
diff --git a/rag/retriever/rerank/abstract.py b/rag/retriever/rerank/abstract.py
index b96b70a..f32ee77 100644
--- a/rag/retriever/rerank/abstract.py
+++ b/rag/retriever/rerank/abstract.py
@@ -1,6 +1,8 @@
 from abc import abstractmethod
+from typing import List
 
-from rag.generator.prompt import Prompt
+from rag.rag import Message
+from rag.retriever.vector import Document
 
 
 class AbstractReranker(type):
@@ -13,5 +15,9 @@ class AbstractReranker(type):
         return cls._instances[cls]
 
     @abstractmethod
-    def rank(self, prompt: Prompt) -> Prompt:
-        return prompt
+    def rerank_documents(self, query: str, documents: List[Document]) -> List[Document]:
+        pass
+
+    @abstractmethod
+    def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]:
+        pass
diff --git a/rag/retriever/rerank/cohere.py b/rag/retriever/rerank/cohere.py
index 43690a1..33c373d 100644
--- a/rag/retriever/rerank/cohere.py
+++ b/rag/retriever/rerank/cohere.py
@@ -1,10 +1,12 @@
 import os
+from typing import List
 
 import cohere
 from loguru import logger as log
 
-from rag.generator.prompt import Prompt
+from rag.rag import Message
 from rag.retriever.rerank.abstract import AbstractReranker
+from rag.retriever.vector import Document
 
 
 class CohereReranker(metaclass=AbstractReranker):
@@ -13,22 +15,39 @@
         self.top_k = int(os.environ["RERANK_TOP_K"])
         self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"])
 
-    def rank(self, prompt: Prompt) -> Prompt:
-        if prompt.documents:
-            response = self.client.rerank(
-                model="rerank-english-v3.0",
-                query=prompt.query,
-                documents=[d.text for d in prompt.documents],
-                top_n=self.top_k,
-            )
-            ranking = list(
-                filter(
-                    lambda x: x.relevance_score > self.relevance_threshold,
-                    response.results,
-                )
-            )
-            log.debug(
-                f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}"
-            )
-            prompt.documents = [prompt.documents[r.index] for r in ranking]
-        return prompt
+    def rerank_documents(self, query: str, documents: List[Document]) -> List[Document]:
+        if not documents:
+            return []
+        response = self.client.rerank(
+            model="rerank-english-v3.0",
+            query=query,
+            documents=[d.text for d in documents],
+            top_n=self.top_k,
+        )
+        ranking = list(
+            filter(
+                lambda x: x.relevance_score > self.relevance_threshold,
+                response.results,
+            )
+        )
+        log.debug(
+            f"Reranking gave {len(ranking)} relevant documents of {len(documents)}"
+        )
+        return [documents[r.index] for r in ranking]
+
+    def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]:
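+        # Chat history goes through the same rerank endpoint as documents,
+        # scoring each stored message's content against the current query.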
+        # The history is empty on the first turn; skip the API call.
+        if not messages:
+            return []
+        response = self.client.rerank(
+            model="rerank-english-v3.0",
+            query=query,
+            documents=[m.content for m in messages],
+            top_n=self.top_k,
+        )
+        ranking = list(
+            filter(
+                lambda x: x.relevance_score > self.relevance_threshold,
+                response.results,
+            )
+        )
+        log.debug(
+            f"Reranking gave {len(ranking)} relevant chat messages of {len(messages)}"
+        )
+        return [messages[r.index] for r in ranking]
diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py
index 8e94882..e727165 100644
--- a/rag/retriever/rerank/local.py
+++ b/rag/retriever/rerank/local.py
@@ -2,10 +2,10 @@ import os
 from typing import List
 
 from loguru import logger as log
+from rag.rag import Message
+from rag.retriever.vector import Document
 from sentence_transformers import CrossEncoder
 
-from rag.generator.prompt import Prompt
-from rag.retriever.memory import Log
 from rag.retriever.rerank.abstract import AbstractReranker
 
 
@@ -15,42 +15,32 @@ class Reranker(metaclass=AbstractReranker):
         self.top_k = int(os.environ["RERANK_TOP_K"])
         self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"])
 
-    def rank(self, prompt: Prompt) -> Prompt:
-        if prompt.documents:
-            results = self.model.rank(
-                query=prompt.query,
-                documents=[d.text for d in prompt.documents],
-                return_documents=False,
-                top_k=self.top_k,
-            )
-            ranking = list(
-                filter(
-                    lambda x: x.get("score", 0.0) > self.relevance_threshold, results
-                )
-            )
-            log.debug(
-                f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}"
-            )
-            prompt.documents = [
-                prompt.documents[r.get("corpus_id", 0)] for r in ranking
-            ]
-        return prompt
-
-    def rank_memory(self, prompt: Prompt, history: List[Log]) -> List[Log]:
-        if history:
-            results = self.model.rank(
-                query=prompt.query,
-                documents=[m.bot.message for m in history],
-                return_documents=False,
-                top_k=self.top_k,
-            )
-            ranking = list(
-                filter(
-                    lambda x: x.get("score", 0.0) > self.relevance_threshold, results
-                )
-            )
-            log.debug(
-                f"Reranking gave {len(ranking)} relevant messages of {len(history)}"
-            )
-            history = [history[r.get("corpus_id", 0)] for r in ranking]
-        return history
+    def rerank_documents(self, query: str, documents: List[Document]) -> List[Document]:
+        if not documents:
+            return []
+        results = self.model.rank(
+            query=query,
+            documents=[d.text for d in documents],
+            return_documents=False,
+            top_k=self.top_k,
+        )
+        ranking = list(
+            filter(lambda x: x.get("score", 0.0) > self.relevance_threshold, results)
+        )
+        log.debug(
+            f"Reranking gave {len(ranking)} relevant documents of {len(documents)}"
+        )
+        return [documents[r.get("corpus_id", 0)] for r in ranking]
+
+    def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]:
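+        # CrossEncoder.rank returns dicts with "corpus_id" (an index into
+        # the input list) and "score"; filter on score, then map the indices
+        # back to the stored messages.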
+        if not messages:
+            return []
+        results = self.model.rank(
+            query=query,
+            documents=[m.content for m in messages],
+            return_documents=False,
+            top_k=self.top_k,
+        )
+        ranking = list(
+            filter(lambda x: x.get("score", 0.0) > self.relevance_threshold, results)
+        )
+        log.debug(
+            f"Reranking gave {len(ranking)} relevant chat messages of {len(messages)}"
+        )
+        return [messages[r.get("corpus_id", 0)] for r in ranking]
diff --git a/rag/ui.py b/rag/ui.py
index ddb3d78..a453f47 100644
--- a/rag/ui.py
+++ b/rag/ui.py
@@ -20,15 +20,9 @@ def set_chat_users():
     ss.bot = "assistant"
 
 
-@st.cache_resource
-def load_retriever():
-    log.debug("Loading retriever model")
-    st.session_state.retriever = Retriever()
-
-
 def load_generator(model: str):
-    log.debug("Loading generator model")
-    st.session_state.generator = get_generator(model)
+    log.debug("Loading rag")
+    st.session_state.rag = Rag(model)
 
 
 def load_reranker(model: str):
@@ -70,17 +64,10 @@ def generate_chat(query: str):
     with st.chat_message(ss.user):
         st.write(query)
 
-    retriever = ss.retriever
-    generator = ss.generator
-    reranker = ss.reranker
-
-    documents = retriever.retrieve(query)
-    prompt = Prompt(query, documents)
-
-    prompt = reranker.rank(prompt)
-
+    rag = ss.rag
+    documents = rag.retrieve(query)
+    prompt = Prompt(query, documents, rag.client)
     with st.chat_message(ss.bot):
-        response = st.write_stream(generator.generate(prompt))
+        response = st.write_stream(rag.generate(prompt))
 
     display_context(prompt.documents)
     store_chat(query, response, prompt.documents)
-- 
cgit v1.2.3-70-g09d2