diff options
Diffstat (limited to 'rag/retriever/rerank/local.py')
-rw-r--r-- | rag/retriever/rerank/local.py | 70 |
1 files changed, 30 insertions, 40 deletions
diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py index 8e94882..e727165 100644 --- a/rag/retriever/rerank/local.py +++ b/rag/retriever/rerank/local.py @@ -2,10 +2,10 @@ import os from typing import List from loguru import logger as log +from rag.rag import Message +from rag.retriever.vector import Document from sentence_transformers import CrossEncoder -from rag.generator.prompt import Prompt -from rag.retriever.memory import Log from rag.retriever.rerank.abstract import AbstractReranker @@ -15,42 +15,32 @@ class Reranker(metaclass=AbstractReranker): self.top_k = int(os.environ["RERANK_TOP_K"]) self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"]) - def rank(self, prompt: Prompt) -> Prompt: - if prompt.documents: - results = self.model.rank( - query=prompt.query, - documents=[d.text for d in prompt.documents], - return_documents=False, - top_k=self.top_k, - ) - ranking = list( - filter( - lambda x: x.get("score", 0.0) > self.relevance_threshold, results - ) - ) - log.debug( - f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}" - ) - prompt.documents = [ - prompt.documents[r.get("corpus_id", 0)] for r in ranking - ] - return prompt + def rerank_documents(self, query: str, documents: List[Document]) -> List[str]: + results = self.model.rank( + query=query, + documents=[d.text for d in documents], + return_documents=False, + top_k=self.top_k, + ) + ranking = list( + filter(lambda x: x.get("score", 0.0) > self.relevance_threshold, results) + ) + log.debug( + f"Reranking gave {len(ranking)} relevant documents of {len(documents)}" + ) + return [documents[r.get("corpus_id", 0)] for r in ranking] - def rank_memory(self, prompt: Prompt, history: List[Log]) -> List[Log]: - if history: - results = self.model.rank( - query=prompt.query, - documents=[m.bot.message for m in history], - return_documents=False, - top_k=self.top_k, - ) - ranking = list( - filter( - lambda x: x.get("score", 0.0) > self.relevance_threshold, results - ) - ) - log.debug( - f"Reranking gave {len(ranking)} relevant messages of {len(history)}" - ) - history = [history[r.get("corpus_id", 0)] for r in ranking] - return history + def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]: + results = self.model.rank( + query=query, + documents=[m.message for m in messages], + return_documents=False, + top_k=self.top_k, + ) + ranking = list( + filter(lambda x: x.get("score", 0.0) > self.relevance_threshold, results) + ) + log.debug( + f"Reranking gave {len(ranking)} relevant chat messages of {len(messages)}" + ) + return [messages[r.get("corpus_id", 0)] for r in ranking] |