diff options
| author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-05-29 00:53:39 +0200 | 
|---|---|---|
| committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-05-29 00:53:39 +0200 | 
| commit | 716e3fe58adee5b8a6bfa91de4b3ba6cf204d172 (patch) | |
| tree | 778da9011d21051006fc206ce0978f0fc114b77b /rag/retriever/rerank | |
| parent | 2d91c118d71a8dd7fbd7f9cf21f86e92da33827e (diff) | |
Wip memory
Diffstat (limited to 'rag/retriever/rerank')
| -rw-r--r-- | rag/retriever/rerank/local.py | 21 | 
1 files changed, 21 insertions, 0 deletions
diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py index 75fedd8..8e94882 100644 --- a/rag/retriever/rerank/local.py +++ b/rag/retriever/rerank/local.py @@ -1,9 +1,11 @@  import os +from typing import List  from loguru import logger as log  from sentence_transformers import CrossEncoder  from rag.generator.prompt import Prompt +from rag.retriever.memory import Log  from rag.retriever.rerank.abstract import AbstractReranker @@ -33,3 +35,22 @@ class Reranker(metaclass=AbstractReranker):                  prompt.documents[r.get("corpus_id", 0)] for r in ranking              ]          return prompt + +    def rank_memory(self, prompt: Prompt, history: List[Log]) -> List[Log]: +        if history: +            results = self.model.rank( +                query=prompt.query, +                documents=[m.bot.message for m in history], +                return_documents=False, +                top_k=self.top_k, +            ) +            ranking = list( +                filter( +                    lambda x: x.get("score", 0.0) > self.relevance_threshold, results +                ) +            ) +            log.debug( +                f"Reranking gave {len(ranking)} relevant messages of {len(history)}" +            ) +            history = [history[r.get("corpus_id", 0)] for r in ranking] +        return history  |