summaryrefslogtreecommitdiff
path: root/rag/retriever/rerank/local.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/retriever/rerank/local.py')
-rw-r--r--rag/retriever/rerank/local.py70
1 files changed, 30 insertions, 40 deletions
diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py
index 8e94882..e727165 100644
--- a/rag/retriever/rerank/local.py
+++ b/rag/retriever/rerank/local.py
@@ -2,10 +2,10 @@ import os
from typing import List
from loguru import logger as log
+from rag.rag import Message
+from rag.retriever.vector import Document
from sentence_transformers import CrossEncoder
-from rag.generator.prompt import Prompt
-from rag.retriever.memory import Log
from rag.retriever.rerank.abstract import AbstractReranker
@@ -15,42 +15,32 @@ class Reranker(metaclass=AbstractReranker):
self.top_k = int(os.environ["RERANK_TOP_K"])
self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"])
- def rank(self, prompt: Prompt) -> Prompt:
- if prompt.documents:
- results = self.model.rank(
- query=prompt.query,
- documents=[d.text for d in prompt.documents],
- return_documents=False,
- top_k=self.top_k,
- )
- ranking = list(
- filter(
- lambda x: x.get("score", 0.0) > self.relevance_threshold, results
- )
- )
- log.debug(
- f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}"
- )
- prompt.documents = [
- prompt.documents[r.get("corpus_id", 0)] for r in ranking
- ]
- return prompt
+ def rerank_documents(self, query: str, documents: List[Document]) -> List[str]:
+ results = self.model.rank(
+ query=query,
+ documents=[d.text for d in documents],
+ return_documents=False,
+ top_k=self.top_k,
+ )
+ ranking = list(
+ filter(lambda x: x.get("score", 0.0) > self.relevance_threshold, results)
+ )
+ log.debug(
+ f"Reranking gave {len(ranking)} relevant documents of {len(documents)}"
+ )
+ return [documents[r.get("corpus_id", 0)] for r in ranking]
- def rank_memory(self, prompt: Prompt, history: List[Log]) -> List[Log]:
- if history:
- results = self.model.rank(
- query=prompt.query,
- documents=[m.bot.message for m in history],
- return_documents=False,
- top_k=self.top_k,
- )
- ranking = list(
- filter(
- lambda x: x.get("score", 0.0) > self.relevance_threshold, results
- )
- )
- log.debug(
- f"Reranking gave {len(ranking)} relevant messages of {len(history)}"
- )
- history = [history[r.get("corpus_id", 0)] for r in ranking]
- return history
+ def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]:
+ results = self.model.rank(
+ query=query,
+ documents=[m.message for m in messages],
+ return_documents=False,
+ top_k=self.top_k,
+ )
+ ranking = list(
+ filter(lambda x: x.get("score", 0.0) > self.relevance_threshold, results)
+ )
+ log.debug(
+ f"Reranking gave {len(ranking)} relevant chat messages of {len(messages)}"
+ )
+ return [messages[r.get("corpus_id", 0)] for r in ranking]