summaryrefslogtreecommitdiff
path: root/rag/retriever/rerank/cohere.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/retriever/rerank/cohere.py')
-rw-r--r--rag/retriever/rerank/cohere.py28
1 files changed, 28 insertions, 0 deletions
diff --git a/rag/retriever/rerank/cohere.py b/rag/retriever/rerank/cohere.py
new file mode 100644
index 0000000..dac9ab5
--- /dev/null
+++ b/rag/retriever/rerank/cohere.py
@@ -0,0 +1,28 @@
+import os
+
+import cohere
+from loguru import logger as log
+
+from rag.generator.prompt import Prompt
+from rag.retriever.rerank.abstract import AbstractReranker
+
+
+class CohereReranker(metaclass=AbstractReranker):
+ def __init__(self) -> None:
+ self.client = cohere.Client(os.environ["COHERE_API_KEY"])
+ self.top_k = int(os.environ["RERANK_TOP_K"])
+
+ def rank(self, prompt: Prompt) -> Prompt:
+ if prompt.documents:
+ response = self.client.rerank(
+ model="rerank-english-v3.0",
+ query=prompt.query,
+ documents=[d.text for d in prompt.documents],
+ top_n=self.top_k,
+ )
+ ranking = list(filter(lambda x: x.relevance_score > 0.5, response.results))
+ log.debug(
+ f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}"
+ )
+ prompt.documents = [prompt.documents[r.index] for r in ranking]
+ return prompt