diff options
| author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-09 00:14:00 +0200 | 
|---|---|---|
| committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-09 00:14:00 +0200 | 
| commit | 91ddb3672e514fa9824609ff047d7cab0c65631a (patch) | |
| tree | 009fd82618588d2960b5207128e86875f73cccdc /rag/llm/cohere_generator.py | |
| parent | d487ef8b04cc7f5ac1491f0638f902fe2abe5ac5 (diff) | |
Refactor
Diffstat (limited to 'rag/llm/cohere_generator.py')
| -rw-r--r-- | rag/llm/cohere_generator.py | 29 | 
1 files changed, 0 insertions, 29 deletions
| diff --git a/rag/llm/cohere_generator.py b/rag/llm/cohere_generator.py deleted file mode 100644 index a6feacd..0000000 --- a/rag/llm/cohere_generator.py +++ /dev/null @@ -1,29 +0,0 @@ -import os -from typing import Any, Generator -import cohere - -from dataclasses import asdict -try: -    from rag.llm.ollama_generator import Prompt -except ModuleNotFoundError: -    from llm.ollama_generator import Prompt -from loguru import logger as log - - -class CohereGenerator: -    def __init__(self) -> None: -        self.client = cohere.Client(os.environ["COHERE_API_KEY"]) - -    def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]: -        log.debug("Generating answer from cohere") -        for event in self.client.chat_stream( -            message=prompt.query, -            documents=[asdict(d) for d in prompt.documents], -            prompt_truncation="AUTO", -        ): -            if event.event_type == "text-generation": -                yield event.text -            elif event.event_type == "citation-generation": -                yield event.citations -            elif event.event_type == "stream-end": -                yield event.finish_reason |