summaryrefslogtreecommitdiff
path: root/rag/retriever/rerank/cohere.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/retriever/rerank/cohere.py')
-rw-r--r--rag/retriever/rerank/cohere.py55
1 files changed, 37 insertions, 18 deletions
diff --git a/rag/retriever/rerank/cohere.py b/rag/retriever/rerank/cohere.py
index 43690a1..33c373d 100644
--- a/rag/retriever/rerank/cohere.py
+++ b/rag/retriever/rerank/cohere.py
@@ -1,10 +1,12 @@
import os
+from typing import List
import cohere
from loguru import logger as log
-from rag.generator.prompt import Prompt
+from rag.rag import Message
from rag.retriever.rerank.abstract import AbstractReranker
+from rag.retriever.vector import Document
class CohereReranker(metaclass=AbstractReranker):
@@ -13,22 +15,39 @@ class CohereReranker(metaclass=AbstractReranker):
self.top_k = int(os.environ["RERANK_TOP_K"])
self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"])
- def rank(self, prompt: Prompt) -> Prompt:
- if prompt.documents:
- response = self.client.rerank(
- model="rerank-english-v3.0",
- query=prompt.query,
- documents=[d.text for d in prompt.documents],
- top_n=self.top_k,
+ def rerank_documents(self, query: str, documents: List[Document]) -> List[str]:
+ response = self.client.rerank(
+ model="rerank-english-v3.0",
+ query=query,
+ documents=[d.text for d in documents],
+ top_n=self.top_k,
+ )
+ ranking = list(
+ filter(
+ lambda x: x.relevance_score > self.relevance_threshold,
+ response.results,
)
- ranking = list(
- filter(
- lambda x: x.relevance_score > self.relevance_threshold,
- response.results,
- )
- )
- log.debug(
- f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}"
+ )
+ log.debug(
+ f"Reranking gave {len(ranking)} relevant documents of {len(documents)}"
+ )
+ return [documents[r.index] for r in ranking]
+
+
+ def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]:
+ response = self.model.rank(
+ query=query,
+ documents=[m.message for m in messages],
+ return_documents=False,
+ top_k=self.top_k,
+ )
+ ranking = list(
+ filter(
+ lambda x: x.relevance_score > self.relevance_threshold,
+ response.results,
)
- prompt.documents = [prompt.documents[r.index] for r in ranking]
- return prompt
+ )
+ log.debug(
+ f"Reranking gave {len(ranking)} relevant chat messages of {len(messages)}"
+ )
+ return [messages[r.index] for r in ranking]