author    Gustaf Rydholm <gustaf.rydholm@gmail.com>  2024-05-29 00:53:39 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com>  2024-05-29 00:53:39 +0200
commit    716e3fe58adee5b8a6bfa91de4b3ba6cf204d172 (patch)
tree      778da9011d21051006fc206ce0978f0fc114b77b /rag/generator
parent    2d91c118d71a8dd7fbd7f9cf21f86e92da33827e (diff)
Wip memory
Diffstat (limited to 'rag/generator')
-rw-r--r--  rag/generator/cohere.py  7
-rw-r--r--  rag/generator/ollama.py  4
2 files changed, 7 insertions, 4 deletions
diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py
index 28a87e7..fb0cc5b 100644
--- a/rag/generator/cohere.py
+++ b/rag/generator/cohere.py
@@ -1,6 +1,6 @@
 import os
 from dataclasses import asdict
-from typing import Any, Generator
+from typing import Any, Dict, Generator, List, Optional
 
 import cohere
 from loguru import logger as log
@@ -13,12 +13,15 @@ class Cohere(metaclass=AbstractGenerator):
     def __init__(self) -> None:
         self.client = cohere.Client(os.environ["COHERE_API_KEY"])
 
-    def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
+    def generate(
+        self, prompt: Prompt, history: Optional[List[Dict[str, str]]]
+    ) -> Generator[Any, Any, Any]:
         log.debug("Generating answer from cohere...")
         query = f"{prompt.query}\n\n{ANSWER_INSTRUCTION}"
         for event in self.client.chat_stream(
             message=query,
             documents=[asdict(d) for d in prompt.documents],
+            chat_history=history,
             prompt_truncation="AUTO",
         ):
             if event.event_type == "text-generation":
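For context, a minimal caller sketch for the new signature; the Prompt/Doc stand-ins and the chat_history message shape (role/message pairs, as the Cohere chat API expects) are assumptions for illustration and not part of this commit:

# Hypothetical usage sketch; Prompt/Doc are illustrative stand-ins, not the repo's real types.
from dataclasses import dataclass
from typing import List

from rag.generator.cohere import Cohere


@dataclass
class Doc:
    title: str
    text: str


@dataclass
class Prompt:
    query: str
    documents: List[Doc]


prompt = Prompt(
    query="What does the report say about Q1 revenue?",
    documents=[Doc(title="report", text="Q1 revenue grew 12% year over year.")],
)
# Prior turns, passed straight through to chat_stream(chat_history=...).
history = [
    {"role": "USER", "message": "Summarize the report."},
    {"role": "CHATBOT", "message": "It covers the Q1 financials."},
]

for token in Cohere().generate(prompt, history):
    print(token, end="", flush=True)

Passing history=None should preserve the previous single-turn behaviour.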
diff --git a/rag/generator/ollama.py b/rag/generator/ollama.py
index 52521ca..9bf551a 100644
--- a/rag/generator/ollama.py
+++ b/rag/generator/ollama.py
@@ -36,8 +36,8 @@ class Ollama(metaclass=AbstractGenerator):
         )
         return metaprompt
 
-    def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
+    def generate(self, prompt: Prompt, memory: Memory) -> Generator[Any, Any, Any]:
         log.debug("Generating answer with ollama...")
         metaprompt = self.__metaprompt(prompt)
-        for chunk in ollama.generate(model=self.model, prompt=metaprompt, stream=True):
+        for chunk in ollama.chat(model=self.model, messages=memory.append(metaprompt), stream=True):
             yield chunk["response"]
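The Memory type used above is not defined in this diff; a minimal sketch of a class (names and behaviour assumed) that would satisfy ollama.chat(messages=memory.append(metaprompt), stream=True) is:

# Hypothetical Memory sketch: append() records the new user turn and returns the
# full message list, since ollama.chat expects [{"role": ..., "content": ...}, ...].
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Memory:
    messages: List[Dict[str, str]] = field(default_factory=list)

    def append(self, prompt: str) -> List[Dict[str, str]]:
        self.messages.append({"role": "user", "content": prompt})
        return self.messages

    def add_answer(self, answer: str) -> None:
        self.messages.append({"role": "assistant", "content": answer})

With ollama.chat, streamed chunks expose their text under chunk["message"]["content"] rather than chunk["response"].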