diff options
Diffstat (limited to 'rag/generator/cohere.py')
-rw-r--r-- | rag/generator/cohere.py | 13 |
1 files changed, 7 insertions, 6 deletions
diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py index fb0cc5b..f30fe69 100644 --- a/rag/generator/cohere.py +++ b/rag/generator/cohere.py @@ -1,12 +1,14 @@ import os from dataclasses import asdict -from typing import Any, Dict, Generator, List, Optional +from typing import Any, Generator, List import cohere from loguru import logger as log +from rag.rag import Message + from .abstract import AbstractGenerator -from .prompt import ANSWER_INSTRUCTION, Prompt +from .prompt import Prompt class Cohere(metaclass=AbstractGenerator): @@ -14,14 +16,13 @@ class Cohere(metaclass=AbstractGenerator): self.client = cohere.Client(os.environ["COHERE_API_KEY"]) def generate( - self, prompt: Prompt, history: Optional[List[Dict[str, str]]] + self, prompt: Prompt, messages: List[Message] ) -> Generator[Any, Any, Any]: log.debug("Generating answer from cohere...") - query = f"{prompt.query}\n\n{ANSWER_INSTRUCTION}" for event in self.client.chat_stream( - message=query, + message=prompt.to_str(), documents=[asdict(d) for d in prompt.documents], - chat_history=history, + chat_history=[m.as_dict() for m in messages], prompt_truncation="AUTO", ): if event.event_type == "text-generation": |