summaryrefslogtreecommitdiff
path: root/rag/retriever/rerank/local.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-24 09:09:24 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-24 09:09:24 +0200
commit9e0cbcb4e7f1f3f95f304046d3190c6ebc4d3901 (patch)
tree5d890ce2705b79f23d63988c140d08edadaf35c5 /rag/retriever/rerank/local.py
parent2e85325639ce3827cc2eb32f9750dfa873e3a480 (diff)
Reformat and fix typo
Diffstat (limited to 'rag/retriever/rerank/local.py')
-rw-r--r--rag/retriever/rerank/local.py30
1 files changed, 30 insertions, 0 deletions
diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py
new file mode 100644
index 0000000..758c5dc
--- /dev/null
+++ b/rag/retriever/rerank/local.py
@@ -0,0 +1,30 @@
+import os
+
+from loguru import logger as log
+from sentence_transformers import CrossEncoder
+
+from rag.generator.prompt import Prompt
+from rag.retriever.rerank.abstract import AbstractReranker
+
+
+class Reranker(metaclass=AbstractReranker):
+ def __init__(self) -> None:
+ self.model = CrossEncoder(os.environ["RERANK_MODEL"])
+ self.top_k = int(os.environ["RERANK_TOP_K"])
+
+ def rank(self, prompt: Prompt) -> Prompt:
+ if prompt.documents:
+ results = self.model.rank(
+ query=prompt.query,
+ documents=[d.text for d in prompt.documents],
+ return_documents=False,
+ top_k=self.top_k,
+ )
+ ranking = list(filter(lambda x: x.get("score", 0.0) > 0.5, results))
+ log.debug(
+ f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}"
+ )
+ prompt.documents = [
+ prompt.documents[r.get("corpus_id", 0)] for r in ranking
+ ]
+ return prompt