blob: 575452f8aa6172324469f08de195c173c9592c04 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
|
import os
from dataclasses import asdict
from typing import Any, Generator, List
import cohere
from loguru import logger as log
from rag.message import Message
from rag.retriever.vector import Document
from .abstract import AbstractGenerator
class Cohere(metaclass=AbstractGenerator):
    """Answer generator that streams responses from the Cohere chat API."""

    def __init__(self) -> None:
        """Build a Cohere client authenticated with ``COHERE_API_KEY``.

        Raises:
            KeyError: if the ``COHERE_API_KEY`` environment variable is unset.
        """
        api_key = os.environ["COHERE_API_KEY"]
        self.client = cohere.Client(api_key)

    def generate(
        self, messages: List[Message], documents: List[Document]
    ) -> Generator[Any, Any, Any]:
        """Stream an answer for the latest message, grounded in ``documents``.

        The final entry of ``messages`` is sent as the query, and all earlier
        entries are sent as chat history. Each document is converted to a dict
        with ``dataclasses.asdict``. Prompt truncation is left to the service
        ("AUTO").

        Yields:
            The event's ``text`` for "text-generation" events, its
            ``citations`` for "citation-generation" events, and its
            ``finish_reason`` for the "stream-end" event. All other event
            types are skipped.

        Raises:
            IndexError: on iteration, if ``messages`` is empty.
        """
        log.debug("Generating answer from cohere...")

        # Computed in the same order as the original keyword arguments.
        query = messages[-1].content
        grounding = [asdict(doc) for doc in documents]
        history = [msg.as_dict() for msg in messages[:-1]]

        stream = self.client.chat_stream(
            message=query,
            documents=grounding,
            chat_history=history,
            prompt_truncation="AUTO",
        )
        for event in stream:
            kind = event.event_type
            if kind == "text-generation":
                yield event.text
            elif kind == "citation-generation":
                yield event.citations
            elif kind == "stream-end":
                yield event.finish_reason
|