From b1ff0c55422d7b0af2c379679b8721014ef36926 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Tue, 18 Jun 2024 01:37:32 +0200
Subject: Wip rewrite

---
 rag/retriever/rerank/abstract.py | 12 +++++--
 rag/retriever/rerank/cohere.py   | 55 +++++++++++++++++++++----------
 rag/retriever/rerank/local.py    | 70 +++++++++++++++++-----------------------
 3 files changed, 76 insertions(+), 61 deletions(-)

(limited to 'rag/retriever/rerank')

diff --git a/rag/retriever/rerank/abstract.py b/rag/retriever/rerank/abstract.py
index b96b70a..f32ee77 100644
--- a/rag/retriever/rerank/abstract.py
+++ b/rag/retriever/rerank/abstract.py
@@ -1,6 +1,8 @@
 from abc import abstractmethod
+from typing import List
 
-from rag.generator.prompt import Prompt
+from rag.memory import Message
+from rag.retriever.vector import Document
 
 
 class AbstractReranker(type):
@@ -13,5 +15,9 @@ class AbstractReranker(type):
         return cls._instances[cls]
 
     @abstractmethod
-    def rank(self, prompt: Prompt) -> Prompt:
-        return prompt
+    def rerank_documents(self, query: str, documents: List[Document]) -> List[Document]:
+        pass
+
+    @abstractmethod
+    def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]:
+        pass
diff --git a/rag/retriever/rerank/cohere.py b/rag/retriever/rerank/cohere.py
index 43690a1..33c373d 100644
--- a/rag/retriever/rerank/cohere.py
+++ b/rag/retriever/rerank/cohere.py
@@ -1,10 +1,12 @@
 import os
+from typing import List
 
 import cohere
 from loguru import logger as log
 
-from rag.generator.prompt import Prompt
+from rag.rag import Message
 from rag.retriever.rerank.abstract import AbstractReranker
+from rag.retriever.vector import Document
 
 
 class CohereReranker(metaclass=AbstractReranker):
@@ -13,22 +15,39 @@ class CohereReranker(metaclass=AbstractReranker):
         self.top_k = int(os.environ["RERANK_TOP_K"])
         self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"])
 
-    def rank(self, prompt: Prompt) -> Prompt:
-        if prompt.documents:
-            response = self.client.rerank(
-                model="rerank-english-v3.0",
-                query=prompt.query,
-                documents=[d.text for d in prompt.documents],
-                top_n=self.top_k,
+    def rerank_documents(self, query: str, documents: List[Document]) -> List[str]:
+        response = self.client.rerank(
+            model="rerank-english-v3.0",
+            query=query,
+            documents=[d.text for d in documents],
+            top_n=self.top_k,
+        )
+        ranking = list(
+            filter(
+                lambda x: x.relevance_score > self.relevance_threshold,
+                response.results,
             )
-            ranking = list(
-                filter(
-                    lambda x: x.relevance_score > self.relevance_threshold,
-                    response.results,
-                )
-            )
-            log.debug(
-                f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}"
+        )
+        log.debug(
+            f"Reranking gave {len(ranking)} relevant documents of {len(documents)}"
+        )
+        return [documents[r.index] for r in ranking]
+
+
+    def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]:
+        response = self.model.rank(
+            query=query,
+            documents=[m.message for m in messages],
+            return_documents=False,
+            top_k=self.top_k,
+        )
+        ranking = list(
+            filter(
+                lambda x: x.relevance_score > self.relevance_threshold,
+                response.results,
             )
-            prompt.documents = [prompt.documents[r.index] for r in ranking]
-        return prompt
+        )
+        log.debug(
+            f"Reranking gave {len(ranking)} relevant chat messages of {len(messages)}"
+        )
+        return [messages[r.index] for r in ranking]
diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py
index 8e94882..e727165 100644
--- a/rag/retriever/rerank/local.py
+++ b/rag/retriever/rerank/local.py
@@ -2,10 +2,10 @@ import os
 from typing import List
 
 from loguru import logger as log
+from rag.rag import Message
+from rag.retriever.vector import Document
 from sentence_transformers import CrossEncoder
 
-from rag.generator.prompt import Prompt
-from rag.retriever.memory import Log
 from rag.retriever.rerank.abstract import AbstractReranker
 
 
@@ -15,42 +15,32 @@ class Reranker(metaclass=AbstractReranker):
         self.top_k = int(os.environ["RERANK_TOP_K"])
         self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"])
 
-    def rank(self, prompt: Prompt) -> Prompt:
-        if prompt.documents:
-            results = self.model.rank(
-                query=prompt.query,
-                documents=[d.text for d in prompt.documents],
-                return_documents=False,
-                top_k=self.top_k,
-            )
-            ranking = list(
-                filter(
-                    lambda x: x.get("score", 0.0) > self.relevance_threshold, results
-                )
-            )
-            log.debug(
-                f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}"
-            )
-            prompt.documents = [
-                prompt.documents[r.get("corpus_id", 0)] for r in ranking
-            ]
-        return prompt
+    def rerank_documents(self, query: str, documents: List[Document]) -> List[str]:
+        results = self.model.rank(
+            query=query,
+            documents=[d.text for d in documents],
+            return_documents=False,
+            top_k=self.top_k,
+        )
+        ranking = list(
+            filter(lambda x: x.get("score", 0.0) > self.relevance_threshold, results)
+        )
+        log.debug(
+            f"Reranking gave {len(ranking)} relevant documents of {len(documents)}"
+        )
+        return [documents[r.get("corpus_id", 0)] for r in ranking]
 
-    def rank_memory(self, prompt: Prompt, history: List[Log]) -> List[Log]:
-        if history:
-            results = self.model.rank(
-                query=prompt.query,
-                documents=[m.bot.message for m in history],
-                return_documents=False,
-                top_k=self.top_k,
-            )
-            ranking = list(
-                filter(
-                    lambda x: x.get("score", 0.0) > self.relevance_threshold, results
-                )
-            )
-            log.debug(
-                f"Reranking gave {len(ranking)} relevant messages of {len(history)}"
-            )
-            history = [history[r.get("corpus_id", 0)] for r in ranking]
-        return history
+    def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]:
+        results = self.model.rank(
+            query=query,
+            documents=[m.message for m in messages],
+            return_documents=False,
+            top_k=self.top_k,
+        )
+        ranking = list(
+            filter(lambda x: x.get("score", 0.0) > self.relevance_threshold, results)
+        )
+        log.debug(
+            f"Reranking gave {len(ranking)} relevant chat messages of {len(messages)}"
+        )
+        return [messages[r.get("corpus_id", 0)] for r in ranking]
--
cgit v1.2.3-70-g09d2
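
Below is a small usage sketch of the reworked reranker interface introduced by this patch; it is not part of the commit. The environment variable names are the ones read in the patched __init__ methods, while the cross-encoder model name and the Document(text=...) constructor are illustrative assumptions.

    # Hypothetical usage of the new rerank_documents() API (illustration only).
    import os

    os.environ["RERANK_MODEL"] = "cross-encoder/ms-marco-MiniLM-L-6-v2"  # assumed model name
    os.environ["RERANK_TOP_K"] = "5"
    os.environ["RETRIEVER_RELEVANCE_THRESHOLD"] = "0.5"

    from rag.retriever.rerank.local import Reranker
    from rag.retriever.vector import Document

    # The AbstractReranker metaclass caches instances, so Reranker() behaves as a singleton.
    reranker = Reranker()

    docs = [
        Document(text="Paris is the capital of France."),  # assumed constructor; the reranker only reads .text
        Document(text="Bananas are rich in potassium."),
    ]

    # rerank_documents() scores each document against the query with the local cross-encoder,
    # keeps only results scoring above RETRIEVER_RELEVANCE_THRESHOLD, and returns at most
    # RERANK_TOP_K documents.
    relevant = reranker.rerank_documents("What is the capital of France?", docs)
    for doc in relevant:
        print(doc.text)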