diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-05-08 21:14:27 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-05-08 21:14:27 +0200 |
commit | 7ab76ac97024eff6cbe559cc158e840de01a39a8 (patch) | |
tree | e9b0ce4ec5cf4f2d6801cd9068cdb736f1650200 /rag/retriever | |
parent | 441d38e5be98fc7f060f1221181cc9f8c130cdba (diff) |
Set relevance threshold as a env
Diffstat (limited to 'rag/retriever')
-rw-r--r-- | rag/retriever/rerank/cohere.py | 8 | ||||
-rw-r--r-- | rag/retriever/rerank/local.py | 7 |
2 files changed, 13 insertions, 2 deletions
diff --git a/rag/retriever/rerank/cohere.py b/rag/retriever/rerank/cohere.py index dac9ab5..43690a1 100644 --- a/rag/retriever/rerank/cohere.py +++ b/rag/retriever/rerank/cohere.py @@ -11,6 +11,7 @@ class CohereReranker(metaclass=AbstractReranker): def __init__(self) -> None: self.client = cohere.Client(os.environ["COHERE_API_KEY"]) self.top_k = int(os.environ["RERANK_TOP_K"]) + self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"]) def rank(self, prompt: Prompt) -> Prompt: if prompt.documents: @@ -20,7 +21,12 @@ class CohereReranker(metaclass=AbstractReranker): documents=[d.text for d in prompt.documents], top_n=self.top_k, ) - ranking = list(filter(lambda x: x.relevance_score > 0.5, response.results)) + ranking = list( + filter( + lambda x: x.relevance_score > self.relevance_threshold, + response.results, + ) + ) log.debug( f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}" ) diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py index 758c5dc..75fedd8 100644 --- a/rag/retriever/rerank/local.py +++ b/rag/retriever/rerank/local.py @@ -11,6 +11,7 @@ class Reranker(metaclass=AbstractReranker): def __init__(self) -> None: self.model = CrossEncoder(os.environ["RERANK_MODEL"]) self.top_k = int(os.environ["RERANK_TOP_K"]) + self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"]) def rank(self, prompt: Prompt) -> Prompt: if prompt.documents: @@ -20,7 +21,11 @@ class Reranker(metaclass=AbstractReranker): return_documents=False, top_k=self.top_k, ) - ranking = list(filter(lambda x: x.get("score", 0.0) > 0.5, results)) + ranking = list( + filter( + lambda x: x.get("score", 0.0) > self.relevance_threshold, results + ) + ) log.debug( f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}" ) |