diff options
Diffstat (limited to 'rag/retriever/rerank')
| -rw-r--r-- | rag/retriever/rerank/local.py | 45 | 
1 files changed, 19 insertions, 26 deletions
diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py index 231d50a..e2bef31 100644 --- a/rag/retriever/rerank/local.py +++ b/rag/retriever/rerank/local.py @@ -1,46 +1,39 @@  import os -from typing import List  from loguru import logger as log  from sentence_transformers import CrossEncoder -from rag.message import Message +from rag.message import Messages +from rag.retriever.encoder import Query  from rag.retriever.rerank.abstract import AbstractReranker -from rag.retriever.vector import Document +from rag.retriever.vector import Documents + +Context = Documents | Messages  class Reranker(metaclass=AbstractReranker):      def __init__(self) -> None:          self.model = CrossEncoder(os.environ["RERANK_MODEL"], device="cpu")          self.top_k = int(os.environ["RERANK_TOP_K"]) -        self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"]) - -    def rerank_documents(self, query: str, documents: List[Document]) -> List[str]: -        results = self.model.rank( -            query=query, -            documents=[d.text for d in documents], -            return_documents=False, -            top_k=self.top_k, -        ) -        ranking = list( -            filter(lambda x: x.get("score", 0.0) > self.relevance_threshold, results) -        ) -        log.debug( -            f"Reranking gave {len(ranking)} relevant documents of {len(documents)}" -        ) -        return [documents[r.get("corpus_id", 0)] for r in ranking] +        self.relevance_threshold = float(os.environ["RERANK_RELEVANCE_THRESHOLD"]) -    def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]: +    def rerank(self, query: Query, documents: Context) -> Context:          results = self.model.rank( -            query=query, -            documents=[m.content for m in messages], +            query=query.query, +            documents=documents.content(),              return_documents=False,              top_k=self.top_k,          ) -        ranking = list( -            filter(lambda x: x.get("score", 0.0) > self.relevance_threshold, results) +        rankings = list( +            map( +                lambda x: x.get("corpus_id", 0), +                filter( +                    lambda x: x.get("score", 0.0) > self.relevance_threshold, results +                ), +            )          )          log.debug( -            f"Reranking gave {len(ranking)} relevant chat messages of {len(messages)}" +            f"Reranking gave {len(rankings)} relevant documents of {len(documents)}"          ) -        return [messages[r.get("corpus_id", 0)] for r in ranking] +        documents.rerank(rankings) +        return documents  |