author	Gustaf Rydholm <gustaf.rydholm@gmail.com>	2024-04-23 00:49:36 +0200
committer	Gustaf Rydholm <gustaf.rydholm@gmail.com>	2024-04-23 00:49:36 +0200
commit	d6ce5be4889e462fa69968feb90f8ebad30cdb0f (patch)
tree	4a10d4c7654b1a72c7eca50a64adad3e1593ef06 /rag/generator/cohere.py
parent	03a5a027db56932d17ae04f4054895f070d955d0 (diff)
Add reranking
Diffstat (limited to 'rag/generator/cohere.py')
-rw-r--r--	rag/generator/cohere.py	35
1 file changed, 16 insertions, 19 deletions
diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py
index 16dfe88..3499e3b 100644
--- a/rag/generator/cohere.py
+++ b/rag/generator/cohere.py
@@ -1,6 +1,6 @@
 import os
 from dataclasses import asdict
-from typing import Any, Dict, Generator, List
+from typing import Any, Generator
 
 import cohere
 from loguru import logger as log
@@ -13,8 +13,23 @@ class Cohere(metaclass=AbstractGenerator):
     def __init__(self) -> None:
         self.client = cohere.Client(os.environ["COHERE_API_KEY"])
 
+    def rerank(self, prompt: Prompt):
+        response = self.client.rerank(
+            model="rerank-english-v3.0",
+            query=prompt.query,
+            documents=[d.text for d in prompt.documents],
+            top_n=3,
+        )
+        ranking = list(
+            filter(lambda x: x.relevance_score > 0.5, response.results)
+        )
+        prompt.documents = [prompt.documents[r.index] for r in ranking]
+        return prompt
+
     def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
         log.debug("Generating answer from cohere...")
+        if prompt.documents:
+            prompt = self.rerank(prompt)
         query = f"{prompt.query}\n\n{ANSWER_INSTRUCTION}"
         for event in self.client.chat_stream(
             message=query,
@@ -27,21 +42,3 @@ class Cohere(metaclass=AbstractGenerator):
                 yield event.citations
             elif event.event_type == "stream-end":
                 yield event.finish_reason
-
-    def chat(
-        self, prompt: Prompt, messages: List[Dict[str, str]]
-    ) -> Generator[Any, Any, Any]:
-        log.debug("Generating answer from cohere...")
-        query = f"{prompt.query}\n\n{ANSWER_INSTRUCTION}"
-        for event in self.client.chat_stream(
-            message=query,
-            documents=[asdict(d) for d in prompt.documents],
-            chat_history=messages,
-            prompt_truncation="AUTO",
-        ):
-            if event.event_type == "text-generation":
-                yield event.text
-            # elif event.event_type == "citation-generation":
-            #     yield event.citations
-            elif event.event_type == "stream-end":
-                yield event.finish_reason
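
The new rerank step asks Cohere's rerank-english-v3.0 model to score the retrieved documents against the query, drops any candidate scoring 0.5 or below, and reorders prompt.documents before generate() streams an answer. Below is a minimal standalone sketch of that step, assuming COHERE_API_KEY is set; plain strings stand in for the repo's Prompt documents, whose fields are not shown in this diff.

# Standalone sketch of the reranking step (not part of the commit).
import os

import cohere

client = cohere.Client(os.environ["COHERE_API_KEY"])

query = "How does retrieval-augmented generation work?"
documents = [
    "RAG retrieves passages from a vector store and feeds them to the model.",
    "Bananas are rich in potassium.",
    "Reranking reorders retrieved passages by relevance to the query.",
]

# Score every candidate against the query with the rerank endpoint.
response = client.rerank(
    model="rerank-english-v3.0",
    query=query,
    documents=documents,
    top_n=3,
)

# Keep only hits above the same 0.5 relevance threshold used in the commit,
# preserving the reranker's order (most relevant first).
kept = [documents[r.index] for r in response.results if r.relevance_score > 0.5]
print(kept)

Because the 0.5 cut-off can discard every candidate, generate() may end up streaming with fewer than top_n documents, or with none at all.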