# blob: f30fe6917d7f02a6037abfee238bcfd3df1fcd71 (plain)
import os
from dataclasses import asdict
from typing import Any, Generator, List
import cohere
from loguru import logger as log
from rag.rag import Message
from .abstract import AbstractGenerator
from .prompt import Prompt
class Cohere(metaclass=AbstractGenerator):
    """Answer generator backed by Cohere's streaming chat endpoint.

    The API key is taken from the ``COHERE_API_KEY`` environment variable
    when an instance is created.
    """

    # NOTE(review): AbstractGenerator is used as a *metaclass* here rather
    # than a base class; that only works if it derives from ``type`` —
    # confirm this is intended.

    def __init__(self) -> None:
        """Build the Cohere client from the environment.

        Raises:
            KeyError: if ``COHERE_API_KEY`` is not set.
        """
        api_key = os.environ["COHERE_API_KEY"]
        self.client = cohere.Client(api_key)

    def generate(
        self, prompt: Prompt, messages: List[Message]
    ) -> Generator[Any, Any, Any]:
        """Stream Cohere's answer to *prompt*, given earlier *messages*.

        Each entry of ``prompt.documents`` is converted with
        ``dataclasses.asdict`` and sent as a document. Each message is sent
        as chat history via its ``as_dict()`` method. Prompt truncation is
        left to Cohere (``"AUTO"``).

        Yields, in stream order:
            * ``text`` for every ``"text-generation"`` event,
            * ``citations`` for every ``"citation-generation"`` event,
            * ``finish_reason`` for the ``"stream-end"`` event.
        Events of any other type are skipped silently.
        """
        log.debug("Generating answer from cohere...")

        # Event type -> name of the attribute whose value is surfaced.
        payload_attr = {
            "text-generation": "text",
            "citation-generation": "citations",
            "stream-end": "finish_reason",
        }

        # Resolve the client method before building its arguments, matching
        # the evaluation order of a direct call expression.
        open_stream = self.client.chat_stream
        message_text = prompt.to_str()
        document_dicts = [asdict(doc) for doc in prompt.documents]
        history = [msg.as_dict() for msg in messages]

        stream = open_stream(
            message=message_text,
            documents=document_dicts,
            chat_history=history,
            prompt_truncation="AUTO",
        )
        for chunk in stream:
            attr = payload_attr.get(chunk.event_type)
            if attr is not None:
                yield getattr(chunk, attr)
|