diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-09 00:14:00 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-09 00:14:00 +0200 |
commit | 91ddb3672e514fa9824609ff047d7cab0c65631a (patch) | |
tree | 009fd82618588d2960b5207128e86875f73cccdc /rag/generator/cohere.py | |
parent | d487ef8b04cc7f5ac1491f0638f902fe2abe5ac5 (diff) |
Refactor
Diffstat (limited to 'rag/generator/cohere.py')
-rw-r--r-- | rag/generator/cohere.py | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py new file mode 100644 index 0000000..cf95c18 --- /dev/null +++ b/rag/generator/cohere.py @@ -0,0 +1,29 @@ +import os +from typing import Any, Generator +import cohere + +from dataclasses import asdict + +from .prompt import Prompt +from .abstract import AbstractGenerator + +from loguru import logger as log + + +class Cohere(metaclass=AbstractGenerator): + def __init__(self) -> None: + self.client = cohere.Client(os.environ["COHERE_API_KEY"]) + + def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]: + log.debug("Generating answer from cohere") + for event in self.client.chat_stream( + message=prompt.query, + documents=[asdict(d) for d in prompt.documents], + prompt_truncation="AUTO", + ): + if event.event_type == "text-generation": + yield event.text + elif event.event_type == "citation-generation": + yield event.citations + elif event.event_type == "stream-end": + yield event.finish_reason |