From b1ff0c55422d7b0af2c379679b8721014ef36926 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 18 Jun 2024 01:37:32 +0200 Subject: Wip rewrite --- rag/retriever/rerank/cohere.py | 55 ++++++++++++++++++++++++++++-------------- 1 file changed, 37 insertions(+), 18 deletions(-) (limited to 'rag/retriever/rerank/cohere.py') diff --git a/rag/retriever/rerank/cohere.py b/rag/retriever/rerank/cohere.py index 43690a1..33c373d 100644 --- a/rag/retriever/rerank/cohere.py +++ b/rag/retriever/rerank/cohere.py @@ -1,10 +1,12 @@ import os +from typing import List import cohere from loguru import logger as log -from rag.generator.prompt import Prompt +from rag.rag import Message from rag.retriever.rerank.abstract import AbstractReranker +from rag.retriever.vector import Document class CohereReranker(metaclass=AbstractReranker): @@ -13,22 +15,39 @@ class CohereReranker(metaclass=AbstractReranker): self.top_k = int(os.environ["RERANK_TOP_K"]) self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"]) - def rank(self, prompt: Prompt) -> Prompt: - if prompt.documents: - response = self.client.rerank( - model="rerank-english-v3.0", - query=prompt.query, - documents=[d.text for d in prompt.documents], - top_n=self.top_k, + def rerank_documents(self, query: str, documents: List[Document]) -> List[str]: + response = self.client.rerank( + model="rerank-english-v3.0", + query=query, + documents=[d.text for d in documents], + top_n=self.top_k, + ) + ranking = list( + filter( + lambda x: x.relevance_score > self.relevance_threshold, + response.results, ) - ranking = list( - filter( - lambda x: x.relevance_score > self.relevance_threshold, - response.results, - ) - ) - log.debug( - f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}" + ) + log.debug( + f"Reranking gave {len(ranking)} relevant documents of {len(documents)}" + ) + return [documents[r.index] for r in ranking] + + + def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]: + response = self.model.rank( + query=query, + documents=[m.message for m in messages], + return_documents=False, + top_k=self.top_k, + ) + ranking = list( + filter( + lambda x: x.relevance_score > self.relevance_threshold, + response.results, ) - prompt.documents = [prompt.documents[r.index] for r in ranking] - return prompt + ) + log.debug( + f"Reranking gave {len(ranking)} relevant chat messages of {len(messages)}" + ) + return [messages[r.index] for r in ranking] -- cgit v1.2.3-70-g09d2