summaryrefslogtreecommitdiff
path: root/rag/generator/cohere.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/generator/cohere.py')
-rw-r--r--rag/generator/cohere.py26
1 files changed, 13 insertions, 13 deletions
diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py
index 3499e3b..049aea4 100644
--- a/rag/generator/cohere.py
+++ b/rag/generator/cohere.py
@@ -13,23 +13,23 @@ class Cohere(metaclass=AbstractGenerator):
def __init__(self) -> None:
self.client = cohere.Client(os.environ["COHERE_API_KEY"])
- def rerank(self, prompt: Prompt):
- response = self.client.rerank(
- model="rerank-english-v3.0",
- query=prompt.query,
- documents=[d.text for d in prompt.documents],
- top_n=3,
- )
- ranking = list(
- filter(lambda x: x.relevance_score > 0.5, response.results)
- )
- prompt.documents = [prompt.documents[r.index] for r in ranking]
+ def rerank(self, prompt: Prompt) -> Prompt:
+ if prompt.documents:
+ response = self.client.rerank(
+ model="rerank-english-v3.0",
+ query=prompt.query,
+ documents=[d.text for d in prompt.documents],
+ top_n=3,
+ )
+ ranking = list(filter(lambda x: x.relevance_score > 0.5, response.results))
+ log.debug(
+ f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}"
+ )
+ prompt.documents = [prompt.documents[r.index] for r in ranking]
return prompt
def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
log.debug("Generating answer from cohere...")
- if prompt.documents:
- prompt = self.rerank(prompt)
query = f"{prompt.query}\n\n{ANSWER_INSTRUCTION}"
for event in self.client.chat_stream(
message=query,