summaryrefslogtreecommitdiff
path: root/rag/retriever/rerank
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2024-05-29 00:53:39 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2024-05-29 00:53:39 +0200
commit716e3fe58adee5b8a6bfa91de4b3ba6cf204d172 (patch)
tree778da9011d21051006fc206ce0978f0fc114b77b /rag/retriever/rerank
parent2d91c118d71a8dd7fbd7f9cf21f86e92da33827e (diff)
Wip memory
Diffstat (limited to 'rag/retriever/rerank')
-rw-r--r--rag/retriever/rerank/local.py21
1 files changed, 21 insertions, 0 deletions
diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py
index 75fedd8..8e94882 100644
--- a/rag/retriever/rerank/local.py
+++ b/rag/retriever/rerank/local.py
@@ -1,9 +1,11 @@
import os
+from typing import List
from loguru import logger as log
from sentence_transformers import CrossEncoder
from rag.generator.prompt import Prompt
+from rag.retriever.memory import Log
from rag.retriever.rerank.abstract import AbstractReranker
@@ -33,3 +35,22 @@ class Reranker(metaclass=AbstractReranker):
prompt.documents[r.get("corpus_id", 0)] for r in ranking
]
return prompt
+
+ def rank_memory(self, prompt: Prompt, history: List[Log]) -> List[Log]:
+ if history:
+ results = self.model.rank(
+ query=prompt.query,
+ documents=[m.bot.message for m in history],
+ return_documents=False,
+ top_k=self.top_k,
+ )
+ ranking = list(
+ filter(
+ lambda x: x.get("score", 0.0) > self.relevance_threshold, results
+ )
+ )
+ log.debug(
+ f"Reranking gave {len(ranking)} relevant messages of {len(history)}"
+ )
+ history = [history[r.get("corpus_id", 0)] for r in ranking]
+ return history