summaryrefslogtreecommitdiff
path: root/rag/retriever/rerank/local.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/retriever/rerank/local.py')
-rw-r--r--rag/retriever/rerank/local.py7
1 files changed, 6 insertions, 1 deletions
diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py
index 758c5dc..75fedd8 100644
--- a/rag/retriever/rerank/local.py
+++ b/rag/retriever/rerank/local.py
@@ -11,6 +11,7 @@ class Reranker(metaclass=AbstractReranker):
def __init__(self) -> None:
self.model = CrossEncoder(os.environ["RERANK_MODEL"])
self.top_k = int(os.environ["RERANK_TOP_K"])
+ self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"])
def rank(self, prompt: Prompt) -> Prompt:
if prompt.documents:
@@ -20,7 +21,11 @@ class Reranker(metaclass=AbstractReranker):
return_documents=False,
top_k=self.top_k,
)
- ranking = list(filter(lambda x: x.get("score", 0.0) > 0.5, results))
+ ranking = list(
+ filter(
+ lambda x: x.get("score", 0.0) > self.relevance_threshold, results
+ )
+ )
log.debug(
f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}"
)