diff options
Diffstat (limited to 'rag/retriever/rerank')
| -rw-r--r-- | rag/retriever/rerank/local.py | 21 | 
1 files changed, 21 insertions, 0 deletions
diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py index 75fedd8..8e94882 100644 --- a/rag/retriever/rerank/local.py +++ b/rag/retriever/rerank/local.py @@ -1,9 +1,11 @@  import os +from typing import List  from loguru import logger as log  from sentence_transformers import CrossEncoder  from rag.generator.prompt import Prompt +from rag.retriever.memory import Log  from rag.retriever.rerank.abstract import AbstractReranker @@ -33,3 +35,22 @@ class Reranker(metaclass=AbstractReranker):                  prompt.documents[r.get("corpus_id", 0)] for r in ranking              ]          return prompt + +    def rank_memory(self, prompt: Prompt, history: List[Log]) -> List[Log]: +        if history: +            results = self.model.rank( +                query=prompt.query, +                documents=[m.bot.message for m in history], +                return_documents=False, +                top_k=self.top_k, +            ) +            ranking = list( +                filter( +                    lambda x: x.get("score", 0.0) > self.relevance_threshold, results +                ) +            ) +            log.debug( +                f"Reranking gave {len(ranking)} relevant messages of {len(history)}" +            ) +            history = [history[r.get("corpus_id", 0)] for r in ranking] +        return history  |