# Method of `Reranker` in rag/retriever/rerank/local.py, recovered from the
# line-collapsed patch "Wip memory" (commit 716e3fe, Gustaf Rydholm, 2024-05-29).
# It expects these names to be imported at file level, as the patch does:
# `List` (typing), `log` (loguru's `logger`), `Prompt` (rag.generator.prompt)
# and `Log` (rag.retriever.memory).
def rank_memory(self, prompt: Prompt, history: List[Log]) -> List[Log]:
    """Return the entries of ``history`` that are relevant to ``prompt.query``.

    Every entry's ``bot.message`` is scored against the query through
    ``self.model.rank`` (``self.top_k`` is forwarded as ``top_k``). Only
    results whose ``score`` is strictly greater than
    ``self.relevance_threshold`` are kept.

    Args:
        prompt: Prompt whose ``query`` is used as the ranking query.
        history: Past exchanges to filter; each must expose ``bot.message``.

    Returns:
        The relevant entries, in the order ``self.model.rank`` returned them.
        A falsy ``history`` (empty list or ``None``) is handed back unchanged
        without calling the model.
    """
    # Nothing to rank: skip the model call entirely.
    if not history:
        return history

    results = self.model.rank(
        query=prompt.query,
        documents=[entry.bot.message for entry in history],
        return_documents=False,
        top_k=self.top_k,
    )
    ranking = [
        result
        for result in results
        if result.get("score", 0.0) > self.relevance_threshold
    ]
    log.debug(f"Reranking gave {len(ranking)} relevant messages of {len(history)}")
    # Each result's `corpus_id` is used as an index into `history`, which has
    # the same order as the `documents` list handed to the model above.
    return [history[result.get("corpus_id", 0)] for result in ranking]