diff options
Diffstat (limited to 'rag/generator/cohere.py')
-rw-r--r-- | rag/generator/cohere.py | 26 |
1 files changed, 13 insertions, 13 deletions
diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py index 3499e3b..049aea4 100644 --- a/rag/generator/cohere.py +++ b/rag/generator/cohere.py @@ -13,23 +13,23 @@ class Cohere(metaclass=AbstractGenerator): def __init__(self) -> None: self.client = cohere.Client(os.environ["COHERE_API_KEY"]) - def rerank(self, prompt: Prompt): - response = self.client.rerank( - model="rerank-english-v3.0", - query=prompt.query, - documents=[d.text for d in prompt.documents], - top_n=3, - ) - ranking = list( - filter(lambda x: x.relevance_score > 0.5, response.results) - ) - prompt.documents = [prompt.documents[r.index] for r in ranking] + def rerank(self, prompt: Prompt) -> Prompt: + if prompt.documents: + response = self.client.rerank( + model="rerank-english-v3.0", + query=prompt.query, + documents=[d.text for d in prompt.documents], + top_n=3, + ) + ranking = list(filter(lambda x: x.relevance_score > 0.5, response.results)) + log.debug( + f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}" + ) + prompt.documents = [prompt.documents[r.index] for r in ranking] return prompt def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]: log.debug("Generating answer from cohere...") - if prompt.documents: - prompt = self.rerank(prompt) query = f"{prompt.query}\n\n{ANSWER_INSTRUCTION}" for event in self.client.chat_stream( message=query, |