summaryrefslogtreecommitdiff
path: root/rag/retriever/rerank/local.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/retriever/rerank/local.py')
-rw-r--r--rag/retriever/rerank/local.py45
1 files changed, 19 insertions, 26 deletions
diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py
index 231d50a..e2bef31 100644
--- a/rag/retriever/rerank/local.py
+++ b/rag/retriever/rerank/local.py
@@ -1,46 +1,39 @@
import os
-from typing import List
from loguru import logger as log
from sentence_transformers import CrossEncoder
-from rag.message import Message
+from rag.message import Messages
+from rag.retriever.encoder import Query
from rag.retriever.rerank.abstract import AbstractReranker
-from rag.retriever.vector import Document
+from rag.retriever.vector import Documents
+
+Context = Documents | Messages
class Reranker(metaclass=AbstractReranker):
def __init__(self) -> None:
self.model = CrossEncoder(os.environ["RERANK_MODEL"], device="cpu")
self.top_k = int(os.environ["RERANK_TOP_K"])
- self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"])
-
- def rerank_documents(self, query: str, documents: List[Document]) -> List[str]:
- results = self.model.rank(
- query=query,
- documents=[d.text for d in documents],
- return_documents=False,
- top_k=self.top_k,
- )
- ranking = list(
- filter(lambda x: x.get("score", 0.0) > self.relevance_threshold, results)
- )
- log.debug(
- f"Reranking gave {len(ranking)} relevant documents of {len(documents)}"
- )
- return [documents[r.get("corpus_id", 0)] for r in ranking]
+ self.relevance_threshold = float(os.environ["RERANK_RELEVANCE_THRESHOLD"])
- def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]:
+ def rerank(self, query: Query, documents: Context) -> Context:
results = self.model.rank(
- query=query,
- documents=[m.content for m in messages],
+ query=query.query,
+ documents=documents.content(),
return_documents=False,
top_k=self.top_k,
)
- ranking = list(
- filter(lambda x: x.get("score", 0.0) > self.relevance_threshold, results)
+ rankings = list(
+ map(
+ lambda x: x.get("corpus_id", 0),
+ filter(
+ lambda x: x.get("score", 0.0) > self.relevance_threshold, results
+ ),
+ )
)
log.debug(
- f"Reranking gave {len(ranking)} relevant chat messages of {len(messages)}"
+ f"Reranking gave {len(rankings)} relevant documents of {len(documents)}"
)
- return [messages[r.get("corpus_id", 0)] for r in ranking]
+ documents.rerank(rankings)
+ return documents