summaryrefslogtreecommitdiff
path: root/rag/generator/cohere.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-11 08:59:41 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-11 08:59:41 +0200
commit5b7d1cb49bd473c7dbcf6e89f7d1b6fc8be1f5b1 (patch)
tree41a19d0ff59cb00eca6e30511bbe2d54996b4b2a /rag/generator/cohere.py
parent98f8d1d535c30d8c1ca6c7b52e634a99b88acf10 (diff)
Improve prompt
Diffstat (limited to 'rag/generator/cohere.py')
-rw-r--r--rag/generator/cohere.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py
index 7028b21..2ed2cf5 100644
--- a/rag/generator/cohere.py
+++ b/rag/generator/cohere.py
@@ -6,7 +6,7 @@ import cohere
from loguru import logger as log
from .abstract import AbstractGenerator
-from .prompt import Prompt
+from .prompt import ANSWER_INSTRUCTION, Prompt
class Cohere(metaclass=AbstractGenerator):
@@ -15,8 +15,9 @@ class Cohere(metaclass=AbstractGenerator):
def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
log.debug("Generating answer from cohere")
+ query = f"{prompt.query}\n\n{ANSWER_INSTRUCTION}"
for event in self.client.chat_stream(
- message=prompt.query,
+ message=query,
documents=[asdict(d) for d in prompt.documents],
prompt_truncation="AUTO",
):