From b1ff0c55422d7b0af2c379679b8721014ef36926 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 18 Jun 2024 01:37:32 +0200 Subject: Wip rewrite --- rag/retriever/memory.py | 51 ----------------------------- rag/retriever/rerank/abstract.py | 12 +++++-- rag/retriever/rerank/cohere.py | 55 ++++++++++++++++++++----------- rag/retriever/rerank/local.py | 70 +++++++++++++++++----------------------- 4 files changed, 76 insertions(+), 112 deletions(-) delete mode 100644 rag/retriever/memory.py (limited to 'rag/retriever') diff --git a/rag/retriever/memory.py b/rag/retriever/memory.py deleted file mode 100644 index c4455ed..0000000 --- a/rag/retriever/memory.py +++ /dev/null @@ -1,51 +0,0 @@ -from dataclasses import dataclass -from typing import Dict, List - - -@dataclass -class Log: - user: Message - bot: Message - - def get(): - return (user, bot) - - -@dataclass -class Message: - role: str - message: str - - def as_dict(self, model: str) -> Dict[str, str]: - if model == "cohere": - match self.role: - case "user": - role = "USER" - case _: - role = "CHATBOT" - - return {"role": role, "message": self.message} - else: - return {"role": self.role, "content": self.message} - - -class Memory: - def __init__(self, reranker) -> None: - self.history = [] - self.reranker = reranker - self.user = "user" - self.bot = "assistant" - - def add(self, prompt: str, response: str): - self.history.append( - Log( - user=Message(role=self.user, message=prompt), - bot=Message(role=self.bot, message=response), - ) - ) - - def get(self) -> List[Log]: - return [m.as_dict() for log in self.history for m in log.get()] - - def reset(self): - self.history = [] diff --git a/rag/retriever/rerank/abstract.py b/rag/retriever/rerank/abstract.py index b96b70a..f32ee77 100644 --- a/rag/retriever/rerank/abstract.py +++ b/rag/retriever/rerank/abstract.py @@ -1,6 +1,8 @@ from abc import abstractmethod +from typing import List -from rag.generator.prompt import Prompt +from rag.memory import Message 
+from rag.retriever.vector import Document class AbstractReranker(type): @@ -13,5 +15,9 @@ class AbstractReranker(type): return cls._instances[cls] @abstractmethod - def rank(self, prompt: Prompt) -> Prompt: - return prompt + def rerank_documents(self, query: str, documents: List[Document]) -> List[Document]: + pass + + @abstractmethod + def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]: + pass diff --git a/rag/retriever/rerank/cohere.py b/rag/retriever/rerank/cohere.py index 43690a1..33c373d 100644 --- a/rag/retriever/rerank/cohere.py +++ b/rag/retriever/rerank/cohere.py @@ -1,10 +1,12 @@ import os +from typing import List import cohere from loguru import logger as log -from rag.generator.prompt import Prompt +from rag.memory import Message from rag.retriever.rerank.abstract import AbstractReranker +from rag.retriever.vector import Document class CohereReranker(metaclass=AbstractReranker): @@ -13,22 +15,39 @@ class CohereReranker(metaclass=AbstractReranker): self.top_k = int(os.environ["RERANK_TOP_K"]) self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"]) - def rank(self, prompt: Prompt) -> Prompt: - if prompt.documents: - response = self.client.rerank( - model="rerank-english-v3.0", - query=prompt.query, - documents=[d.text for d in prompt.documents], - top_n=self.top_k, + def rerank_documents(self, query: str, documents: List[Document]) -> List[Document]: + response = self.client.rerank( + model="rerank-english-v3.0", + query=query, + documents=[d.text for d in documents], + top_n=self.top_k, + ) + ranking = list( + filter( + lambda x: x.relevance_score > self.relevance_threshold, + response.results, ) - ranking = list( - filter( - lambda x: x.relevance_score > self.relevance_threshold, - response.results, - ) - ) - log.debug( - f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}" + ) + log.debug( + f"Reranking gave {len(ranking)} relevant documents of {len(documents)}" + ) + return 
[documents[r.index] for r in ranking] + + + def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]: + response = self.client.rerank( + model="rerank-english-v3.0", + query=query, + documents=[m.message for m in messages], + top_n=self.top_k, + ) + ranking = list( + filter( + lambda x: x.relevance_score > self.relevance_threshold, + response.results, ) - prompt.documents = [prompt.documents[r.index] for r in ranking] - return prompt + ) + log.debug( + f"Reranking gave {len(ranking)} relevant chat messages of {len(messages)}" + ) + return [messages[r.index] for r in ranking] diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py index 8e94882..e727165 100644 --- a/rag/retriever/rerank/local.py +++ b/rag/retriever/rerank/local.py @@ -2,10 +2,10 @@ import os from typing import List from loguru import logger as log +from rag.memory import Message +from rag.retriever.vector import Document from sentence_transformers import CrossEncoder -from rag.generator.prompt import Prompt -from rag.retriever.memory import Log from rag.retriever.rerank.abstract import AbstractReranker @@ -15,42 +15,32 @@ class Reranker(metaclass=AbstractReranker): self.top_k = int(os.environ["RERANK_TOP_K"]) self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"]) - def rank(self, prompt: Prompt) -> Prompt: - if prompt.documents: - results = self.model.rank( - query=prompt.query, - documents=[d.text for d in prompt.documents], - return_documents=False, - top_k=self.top_k, - ) - ranking = list( - filter( - lambda x: x.get("score", 0.0) > self.relevance_threshold, results ) - ) - log.debug( - f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}" - ) - prompt.documents = [ - prompt.documents[r.get("corpus_id", 0)] for r in ranking ] - return prompt + def rerank_documents(self, query: str, documents: List[Document]) -> List[Document]: + results = self.model.rank( + query=query, + documents=[d.text for d in documents], + 
return_documents=False, + top_k=self.top_k, + ) + ranking = list( + filter(lambda x: x.get("score", 0.0) > self.relevance_threshold, results) + ) + log.debug( + f"Reranking gave {len(ranking)} relevant documents of {len(documents)}" + ) + return [documents[r.get("corpus_id", 0)] for r in ranking] - def rank_memory(self, prompt: Prompt, history: List[Log]) -> List[Log]: - if history: - results = self.model.rank( - query=prompt.query, - documents=[m.bot.message for m in history], - return_documents=False, - top_k=self.top_k, - ) - ranking = list( - filter( - lambda x: x.get("score", 0.0) > self.relevance_threshold, results - ) - ) - log.debug( - f"Reranking gave {len(ranking)} relevant messages of {len(history)}" - ) - history = [history[r.get("corpus_id", 0)] for r in ranking] - return history + def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]: + results = self.model.rank( + query=query, + documents=[m.message for m in messages], + return_documents=False, + top_k=self.top_k, + ) + ranking = list( + filter(lambda x: x.get("score", 0.0) > self.relevance_threshold, results) + ) + log.debug( + f"Reranking gave {len(ranking)} relevant chat messages of {len(messages)}" + ) + return [messages[r.get("corpus_id", 0)] for r in ranking] -- cgit v1.2.3-70-g09d2