summaryrefslogtreecommitdiff
path: root/rag/generator/cohere.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/generator/cohere.py')
-rw-r--r--rag/generator/cohere.py13
1 files changed, 7 insertions, 6 deletions
diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py
index fb0cc5b..f30fe69 100644
--- a/rag/generator/cohere.py
+++ b/rag/generator/cohere.py
@@ -1,12 +1,14 @@
import os
from dataclasses import asdict
-from typing import Any, Dict, Generator, List, Optional
+from typing import Any, Generator, List
import cohere
from loguru import logger as log
+from rag.rag import Message
+
from .abstract import AbstractGenerator
-from .prompt import ANSWER_INSTRUCTION, Prompt
+from .prompt import Prompt
class Cohere(metaclass=AbstractGenerator):
@@ -14,14 +16,13 @@ class Cohere(metaclass=AbstractGenerator):
self.client = cohere.Client(os.environ["COHERE_API_KEY"])
def generate(
- self, prompt: Prompt, history: Optional[List[Dict[str, str]]]
+ self, prompt: Prompt, messages: List[Message]
) -> Generator[Any, Any, Any]:
log.debug("Generating answer from cohere...")
- query = f"{prompt.query}\n\n{ANSWER_INSTRUCTION}"
for event in self.client.chat_stream(
- message=query,
+ message=prompt.to_str(),
documents=[asdict(d) for d in prompt.documents],
- chat_history=history,
+ chat_history=[m.as_dict() for m in messages],
prompt_truncation="AUTO",
):
if event.event_type == "text-generation":