From 716e3fe58adee5b8a6bfa91de4b3ba6cf204d172 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Wed, 29 May 2024 00:53:39 +0200
Subject: Wip memory

Wire chat history through the generators: the cohere and ollama
generators accept prior messages, a Memory store keeps the
conversation log, and the local reranker can rank past turns.
The new history parameter defaults to None so existing callers
(e.g. rag/ui.py, whose call sites are not touched here) keep working.
---
 rag/generator/cohere.py       |  7 ++++--
 rag/generator/ollama.py       |  4 ++--
 rag/retriever/memory.py       | 51 +++++++++++++++++++++++++++++++++++++++++++
 rag/retriever/rerank/local.py | 21 ++++++++++++++++++
 rag/ui.py                     | 14 ------------
 5 files changed, 79 insertions(+), 18 deletions(-)
 create mode 100644 rag/retriever/memory.py
(limited to 'rag')

diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py
index 28a87e7..fb0cc5b 100644
--- a/rag/generator/cohere.py
+++ b/rag/generator/cohere.py
@@ -1,6 +1,6 @@
 import os
 from dataclasses import asdict
-from typing import Any, Generator
+from typing import Any, Dict, Generator, List, Optional
 
 import cohere
 from loguru import logger as log
@@ -13,12 +13,15 @@ class Cohere(metaclass=AbstractGenerator):
     def __init__(self) -> None:
         self.client = cohere.Client(os.environ["COHERE_API_KEY"])
 
-    def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
+    def generate(
+        self, prompt: Prompt, history: Optional[List[Dict[str, str]]] = None
+    ) -> Generator[Any, Any, Any]:
         log.debug("Generating answer from cohere...")
         query = f"{prompt.query}\n\n{ANSWER_INSTRUCTION}"
         for event in self.client.chat_stream(
             message=query,
             documents=[asdict(d) for d in prompt.documents],
+            chat_history=history,
             prompt_truncation="AUTO",
         ):
             if event.event_type == "text-generation":
diff --git a/rag/generator/ollama.py b/rag/generator/ollama.py
index 52521ca..9bf551a 100644
--- a/rag/generator/ollama.py
+++ b/rag/generator/ollama.py
@@ -36,8 +36,8 @@ class Ollama(metaclass=AbstractGenerator):
         )
         return metaprompt
 
-    def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
+    def generate(self, prompt: Prompt, memory: "Memory") -> Generator[Any, Any, Any]:
         log.debug("Generating answer with ollama...")
         metaprompt = self.__metaprompt(prompt)
-        for chunk in ollama.generate(model=self.model, prompt=metaprompt, 
stream=True):
+        for chunk in ollama.chat(model=self.model, messages=memory.append(metaprompt), stream=True):
             yield chunk["response"]
diff --git a/rag/retriever/memory.py b/rag/retriever/memory.py
new file mode 100644
index 0000000..c4455ed
--- /dev/null
+++ b/rag/retriever/memory.py
@@ -0,0 +1,66 @@
+from dataclasses import dataclass
+from typing import Dict, List, Tuple
+
+
+@dataclass
+class Message:
+    """A single chat message and the role that produced it."""
+
+    role: str
+    message: str
+
+    def as_dict(self, model: str = "ollama") -> Dict[str, str]:
+        """Convert to the message schema expected by the given model API."""
+        if model == "cohere":
+            # Cohere expects upper-case role names and a "message" key.
+            match self.role:
+                case "user":
+                    role = "USER"
+                case _:
+                    role = "CHATBOT"
+            return {"role": role, "message": self.message}
+        else:
+            # Ollama follows the OpenAI-style schema with a "content" key.
+            return {"role": self.role, "content": self.message}
+
+
+@dataclass
+class Log:
+    """One conversation turn: the user prompt and the bot reply."""
+
+    user: Message
+    bot: Message
+
+    def get(self) -> Tuple[Message, Message]:
+        return (self.user, self.bot)
+
+
+class Memory:
+    """Keeps the chat history and formats it for the generator APIs."""
+
+    def __init__(self, reranker) -> None:
+        self.history: List[Log] = []
+        self.reranker = reranker
+        self.user = "user"
+        self.bot = "assistant"
+
+    def add(self, prompt: str, response: str):
+        """Record one completed user/bot exchange."""
+        self.history.append(
+            Log(
+                user=Message(role=self.user, message=prompt),
+                bot=Message(role=self.bot, message=response),
+            )
+        )
+
+    def append(self, prompt: str) -> List[Dict[str, str]]:
+        """Return the history plus a new user message, e.g. for ollama.chat."""
+        return self.get() + [{"role": self.user, "content": prompt}]
+
+    def get(self, model: str = "ollama") -> List[Dict[str, str]]:
+        """Flatten the history into message dicts for the given model."""
+        return [m.as_dict(model) for log in self.history for m in log.get()]
+
+    def reset(self):
+        """Drop the entire conversation history."""
+        self.history = []
diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py
index 75fedd8..8e94882 100644
--- a/rag/retriever/rerank/local.py
+++ b/rag/retriever/rerank/local.py
@@ -1,9 +1,11 @@
 import os
+from typing import List
 
 from loguru import logger as log
 from sentence_transformers import CrossEncoder
 
 from rag.generator.prompt import Prompt
+from rag.retriever.memory import Log
 from rag.retriever.rerank.abstract import AbstractReranker
 
 
@@ -33,3 +35,22 @@ class Reranker(metaclass=AbstractReranker):
                 prompt.documents[r.get("corpus_id", 0)] for r in ranking
             ]
         return prompt
+
+    def rank_memory(self, prompt: Prompt, history: List[Log]) -> List[Log]:
+        if history:
+            results = self.model.rank(
+                query=prompt.query,
+            
documents=[m.bot.message for m in history],
+                return_documents=False,
+                top_k=self.top_k,
+            )
+            ranking = list(
+                filter(
+                    lambda x: x.get("score", 0.0) > self.relevance_threshold, results
+                )
+            )
+            log.debug(
+                f"Reranking gave {len(ranking)} relevant messages of {len(history)}"
+            )
+            history = [history[r.get("corpus_id", 0)] for r in ranking]
+        return history
diff --git a/rag/ui.py b/rag/ui.py
index 36e8c4c..ddb3d78 100644
--- a/rag/ui.py
+++ b/rag/ui.py
@@ -13,18 +13,6 @@ from rag.retriever.retriever import Retriever
 from rag.retriever.vector import Document
 
 
-@dataclass
-class Message:
-    role: str
-    message: str
-
-    def as_dict(self, model: str) -> Dict[str, str]:
-        if model == "cohere":
-            return {"role": self.role, "message": self.message}
-        else:
-            return {"role": self.role, "content": self.message}
-
-
 def set_chat_users():
     log.debug("Setting user and bot value")
     ss = st.session_state
@@ -38,13 +26,11 @@ def load_retriever():
     st.session_state.retriever = Retriever()
 
 
-# @st.cache_resource
 def load_generator(model: str):
     log.debug("Loading generator model")
     st.session_state.generator = get_generator(model)
 
 
-# @st.cache_resource
 def load_reranker(model: str):
     log.debug("Loading reranker model")
     st.session_state.reranker = get_reranker(model)
-- 
cgit v1.2.3-70-g09d2