diff options
Diffstat (limited to 'rag/generator/cohere.py')
-rw-r--r-- | rag/generator/cohere.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py index f30fe69..575452f 100644 --- a/rag/generator/cohere.py +++ b/rag/generator/cohere.py @@ -5,10 +5,10 @@ from typing import Any, Generator, List import cohere from loguru import logger as log -from rag.rag import Message +from rag.message import Message +from rag.retriever.vector import Document from .abstract import AbstractGenerator -from .prompt import Prompt class Cohere(metaclass=AbstractGenerator): @@ -16,13 +16,13 @@ class Cohere(metaclass=AbstractGenerator): self.client = cohere.Client(os.environ["COHERE_API_KEY"]) def generate( - self, prompt: Prompt, messages: List[Message] + self, messages: List[Message], documents: List[Document] ) -> Generator[Any, Any, Any]: log.debug("Generating answer from cohere...") for event in self.client.chat_stream( - message=prompt.to_str(), - documents=[asdict(d) for d in prompt.documents], - chat_history=[m.as_dict() for m in messages], + message=messages[-1].content, + documents=[asdict(d) for d in documents], + chat_history=[m.as_dict() for m in messages[:-1]], prompt_truncation="AUTO", ): if event.event_type == "text-generation": |