diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-23 00:49:36 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-23 00:49:36 +0200 |
commit | d6ce5be4889e462fa69968feb90f8ebad30cdb0f (patch) | |
tree | 4a10d4c7654b1a72c7eca50a64adad3e1593ef06 /rag/generator/cohere.py | |
parent | 03a5a027db56932d17ae04f4054895f070d955d0 (diff) |
Add reranking
Diffstat (limited to 'rag/generator/cohere.py')
-rw-r--r-- | rag/generator/cohere.py | 35 |
1 files changed, 16 insertions, 19 deletions
diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py index 16dfe88..3499e3b 100644 --- a/rag/generator/cohere.py +++ b/rag/generator/cohere.py @@ -1,6 +1,6 @@ import os from dataclasses import asdict -from typing import Any, Dict, Generator, List +from typing import Any, Generator import cohere from loguru import logger as log @@ -13,8 +13,23 @@ class Cohere(metaclass=AbstractGenerator): def __init__(self) -> None: self.client = cohere.Client(os.environ["COHERE_API_KEY"]) + def rerank(self, prompt: Prompt): + response = self.client.rerank( + model="rerank-english-v3.0", + query=prompt.query, + documents=[d.text for d in prompt.documents], + top_n=3, + ) + ranking = list( + filter(lambda x: x.relevance_score > 0.5, response.results) + ) + prompt.documents = [prompt.documents[r.index] for r in ranking] + return prompt + def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]: log.debug("Generating answer from cohere...") + if prompt.documents: + prompt = self.rerank(prompt) query = f"{prompt.query}\n\n{ANSWER_INSTRUCTION}" for event in self.client.chat_stream( message=query, @@ -27,21 +42,3 @@ class Cohere(metaclass=AbstractGenerator): yield event.citations elif event.event_type == "stream-end": yield event.finish_reason - - def chat( - self, prompt: Prompt, messages: List[Dict[str, str]] - ) -> Generator[Any, Any, Any]: - log.debug("Generating answer from cohere...") - query = f"{prompt.query}\n\n{ANSWER_INSTRUCTION}" - for event in self.client.chat_stream( - message=query, - documents=[asdict(d) for d in prompt.documents], - chat_history=messages, - prompt_truncation="AUTO", - ): - if event.event_type == "text-generation": - yield event.text - # elif event.event_type == "citation-generation": - # yield event.citations - elif event.event_type == "stream-end": - yield event.finish_reason |