summaryrefslogtreecommitdiff
path: root/rag/llm/cohere_generator.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/llm/cohere_generator.py')
-rw-r--r--rag/llm/cohere_generator.py29
1 files changed, 0 insertions, 29 deletions
diff --git a/rag/llm/cohere_generator.py b/rag/llm/cohere_generator.py
deleted file mode 100644
index a6feacd..0000000
--- a/rag/llm/cohere_generator.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import os
-from typing import Any, Generator
-import cohere
-
-from dataclasses import asdict
-try:
- from rag.llm.ollama_generator import Prompt
-except ModuleNotFoundError:
- from llm.ollama_generator import Prompt
-from loguru import logger as log
-
-
-class CohereGenerator:
- def __init__(self) -> None:
- self.client = cohere.Client(os.environ["COHERE_API_KEY"])
-
- def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
- log.debug("Generating answer from cohere")
- for event in self.client.chat_stream(
- message=prompt.query,
- documents=[asdict(d) for d in prompt.documents],
- prompt_truncation="AUTO",
- ):
- if event.event_type == "text-generation":
- yield event.text
- elif event.event_type == "citation-generation":
- yield event.citations
- elif event.event_type == "stream-end":
- yield event.finish_reason