summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2024-05-08 21:14:27 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2024-05-08 21:14:27 +0200
commit7ab76ac97024eff6cbe559cc158e840de01a39a8 (patch)
treee9b0ce4ec5cf4f2d6801cd9068cdb736f1650200
parent441d38e5be98fc7f060f1221181cc9f8c130cdba (diff)
Set relevance threshold as a env
-rw-r--r--rag/retriever/rerank/cohere.py8
-rw-r--r--rag/retriever/rerank/local.py7
2 files changed, 13 insertions, 2 deletions
diff --git a/rag/retriever/rerank/cohere.py b/rag/retriever/rerank/cohere.py
index dac9ab5..43690a1 100644
--- a/rag/retriever/rerank/cohere.py
+++ b/rag/retriever/rerank/cohere.py
@@ -11,6 +11,7 @@ class CohereReranker(metaclass=AbstractReranker):
def __init__(self) -> None:
self.client = cohere.Client(os.environ["COHERE_API_KEY"])
self.top_k = int(os.environ["RERANK_TOP_K"])
+ self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"])
def rank(self, prompt: Prompt) -> Prompt:
if prompt.documents:
@@ -20,7 +21,12 @@ class CohereReranker(metaclass=AbstractReranker):
documents=[d.text for d in prompt.documents],
top_n=self.top_k,
)
- ranking = list(filter(lambda x: x.relevance_score > 0.5, response.results))
+ ranking = list(
+ filter(
+ lambda x: x.relevance_score > self.relevance_threshold,
+ response.results,
+ )
+ )
log.debug(
f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}"
)
diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py
index 758c5dc..75fedd8 100644
--- a/rag/retriever/rerank/local.py
+++ b/rag/retriever/rerank/local.py
@@ -11,6 +11,7 @@ class Reranker(metaclass=AbstractReranker):
def __init__(self) -> None:
self.model = CrossEncoder(os.environ["RERANK_MODEL"])
self.top_k = int(os.environ["RERANK_TOP_K"])
+ self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"])
def rank(self, prompt: Prompt) -> Prompt:
if prompt.documents:
@@ -20,7 +21,11 @@ class Reranker(metaclass=AbstractReranker):
return_documents=False,
top_k=self.top_k,
)
- ranking = list(filter(lambda x: x.get("score", 0.0) > 0.5, results))
+ ranking = list(
+ filter(
+ lambda x: x.get("score", 0.0) > self.relevance_threshold, results
+ )
+ )
log.debug(
f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}"
)