summaryrefslogtreecommitdiff
path: root/rag/retriever/rerank/local.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2024-05-08 21:14:27 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2024-05-08 21:14:27 +0200
commit7ab76ac97024eff6cbe559cc158e840de01a39a8 (patch)
treee9b0ce4ec5cf4f2d6801cd9068cdb736f1650200 /rag/retriever/rerank/local.py
parent441d38e5be98fc7f060f1221181cc9f8c130cdba (diff)
Set relevance threshold as a env
Diffstat (limited to 'rag/retriever/rerank/local.py')
-rw-r--r--rag/retriever/rerank/local.py7
1 files changed, 6 insertions, 1 deletions
diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py
index 758c5dc..75fedd8 100644
--- a/rag/retriever/rerank/local.py
+++ b/rag/retriever/rerank/local.py
@@ -11,6 +11,7 @@ class Reranker(metaclass=AbstractReranker):
def __init__(self) -> None:
self.model = CrossEncoder(os.environ["RERANK_MODEL"])
self.top_k = int(os.environ["RERANK_TOP_K"])
+ self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"])
def rank(self, prompt: Prompt) -> Prompt:
if prompt.documents:
@@ -20,7 +21,11 @@ class Reranker(metaclass=AbstractReranker):
return_documents=False,
top_k=self.top_k,
)
- ranking = list(filter(lambda x: x.get("score", 0.0) > 0.5, results))
+ ranking = list(
+ filter(
+ lambda x: x.get("score", 0.0) > self.relevance_threshold, results
+ )
+ )
log.debug(
f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}"
)