summaryrefslogtreecommitdiff
path: root/rag/generator/cohere.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-09 00:14:00 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-09 00:14:00 +0200
commit91ddb3672e514fa9824609ff047d7cab0c65631a (patch)
tree009fd82618588d2960b5207128e86875f73cccdc /rag/generator/cohere.py
parentd487ef8b04cc7f5ac1491f0638f902fe2abe5ac5 (diff)
Refactor
Diffstat (limited to 'rag/generator/cohere.py')
-rw-r--r--rag/generator/cohere.py29
1 files changed, 29 insertions, 0 deletions
diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py
new file mode 100644
index 0000000..cf95c18
--- /dev/null
+++ b/rag/generator/cohere.py
@@ -0,0 +1,29 @@
+import os
+from typing import Any, Generator
+import cohere
+
+from dataclasses import asdict
+
+from .prompt import Prompt
+from .abstract import AbstractGenerator
+
+from loguru import logger as log
+
+
+class Cohere(metaclass=AbstractGenerator):
+ def __init__(self) -> None:
+ self.client = cohere.Client(os.environ["COHERE_API_KEY"])
+
+ def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
+ log.debug("Generating answer from cohere")
+ for event in self.client.chat_stream(
+ message=prompt.query,
+ documents=[asdict(d) for d in prompt.documents],
+ prompt_truncation="AUTO",
+ ):
+ if event.event_type == "text-generation":
+ yield event.text
+ elif event.event_type == "citation-generation":
+ yield event.citations
+ elif event.event_type == "stream-end":
+ yield event.finish_reason