import os
from dataclasses import asdict
from typing import Any, Generator

import cohere
from loguru import logger as log

from .abstract import AbstractGenerator
from .prompt import Prompt

# Maps a Cohere stream event type to the name of the event attribute whose
# value `Cohere.generate` yields. Event types not listed here are skipped.
_PAYLOAD_ATTR_BY_EVENT = {
    "text-generation": "text",
    "citation-generation": "citations",
    "stream-end": "finish_reason",
}


class Cohere(metaclass=AbstractGenerator):
    """Answer generator backed by the Cohere chat streaming API."""

    # NOTE(review): AbstractGenerator is passed as a *metaclass* here; confirm
    # it derives from ``type``/``ABCMeta`` rather than being an ABC base class.

    def __init__(self) -> None:
        """Create a Cohere client authenticated from the environment.

        Raises:
            KeyError: if the ``COHERE_API_KEY`` environment variable is unset.
        """
        api_key = os.environ["COHERE_API_KEY"]
        self.client = cohere.Client(api_key)

    def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
        """Stream Cohere's answer to ``prompt.query``, grounded on its documents.

        Each element of ``prompt.documents`` is converted with
        ``dataclasses.asdict`` (so it must be a dataclass instance) and sent
        as a chat document.

        Yields, in stream order:
            * the text chunk of every ``text-generation`` event,
            * the citations of every ``citation-generation`` event,
            * the finish reason of the ``stream-end`` event.
        Any other event type is silently skipped.
        """
        log.debug("Generating answer from cohere")
        stream = self.client.chat_stream(
            message=prompt.query,
            documents=list(map(asdict, prompt.documents)),
            # Presumably lets the service truncate input to fit its context
            # window -- see the Cohere chat API docs for exact semantics.
            prompt_truncation="AUTO",
        )
        for event in stream:
            payload_attr = _PAYLOAD_ATTR_BY_EVENT.get(event.event_type)
            if payload_attr is None:
                continue
            yield getattr(event, payload_attr)