diff options
Diffstat (limited to 'rag/retriever/rerank/local.py')
-rw-r--r-- | rag/retriever/rerank/local.py | 45 |
1 files changed, 19 insertions, 26 deletions
diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py index 231d50a..e2bef31 100644 --- a/rag/retriever/rerank/local.py +++ b/rag/retriever/rerank/local.py @@ -1,46 +1,39 @@ import os -from typing import List from loguru import logger as log from sentence_transformers import CrossEncoder -from rag.message import Message +from rag.message import Messages +from rag.retriever.encoder import Query from rag.retriever.rerank.abstract import AbstractReranker -from rag.retriever.vector import Document +from rag.retriever.vector import Documents + +Context = Documents | Messages class Reranker(metaclass=AbstractReranker): def __init__(self) -> None: self.model = CrossEncoder(os.environ["RERANK_MODEL"], device="cpu") self.top_k = int(os.environ["RERANK_TOP_K"]) - self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"]) - - def rerank_documents(self, query: str, documents: List[Document]) -> List[str]: - results = self.model.rank( - query=query, - documents=[d.text for d in documents], - return_documents=False, - top_k=self.top_k, - ) - ranking = list( - filter(lambda x: x.get("score", 0.0) > self.relevance_threshold, results) - ) - log.debug( - f"Reranking gave {len(ranking)} relevant documents of {len(documents)}" - ) - return [documents[r.get("corpus_id", 0)] for r in ranking] + self.relevance_threshold = float(os.environ["RERANK_RELEVANCE_THRESHOLD"]) - def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]: + def rerank(self, query: Query, documents: Context) -> Context: results = self.model.rank( - query=query, - documents=[m.content for m in messages], + query=query.query, + documents=documents.content(), return_documents=False, top_k=self.top_k, ) - ranking = list( - filter(lambda x: x.get("score", 0.0) > self.relevance_threshold, results) + rankings = list( + map( + lambda x: x.get("corpus_id", 0), + filter( + lambda x: x.get("score", 0.0) > self.relevance_threshold, results + ), + ) ) log.debug( - f"Reranking gave {len(ranking)} relevant chat messages of {len(messages)}" + f"Reranking gave {len(rankings)} relevant documents of {len(documents)}" ) - return [messages[r.get("corpus_id", 0)] for r in ranking] + documents.rerank(rankings) + return documents |