author     Gustaf Rydholm <gustaf.rydholm@gmail.com>  2024-05-29 00:53:39 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>  2024-05-29 00:53:39 +0200
commit     716e3fe58adee5b8a6bfa91de4b3ba6cf204d172 (patch)
tree       778da9011d21051006fc206ce0978f0fc114b77b /rag/generator
parent     2d91c118d71a8dd7fbd7f9cf21f86e92da33827e (diff)
Wip memory
Diffstat (limited to 'rag/generator')
-rw-r--r--  rag/generator/cohere.py  7
-rw-r--r--  rag/generator/ollama.py  4
2 files changed, 7 insertions, 4 deletions
diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py
index 28a87e7..fb0cc5b 100644
--- a/rag/generator/cohere.py
+++ b/rag/generator/cohere.py
@@ -1,6 +1,6 @@
 import os
 from dataclasses import asdict
-from typing import Any, Generator
+from typing import Any, Dict, Generator, List, Optional
 
 import cohere
 from loguru import logger as log
@@ -13,12 +13,15 @@ class Cohere(metaclass=AbstractGenerator):
     def __init__(self) -> None:
         self.client = cohere.Client(os.environ["COHERE_API_KEY"])
 
-    def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
+    def generate(
+        self, prompt: Prompt, history: Optional[List[Dict[str, str]]]
+    ) -> Generator[Any, Any, Any]:
         log.debug("Generating answer from cohere...")
         query = f"{prompt.query}\n\n{ANSWER_INSTRUCTION}"
         for event in self.client.chat_stream(
             message=query,
             documents=[asdict(d) for d in prompt.documents],
+            chat_history=history,
             prompt_truncation="AUTO",
         ):
             if event.event_type == "text-generation":
diff --git a/rag/generator/ollama.py b/rag/generator/ollama.py
index 52521ca..9bf551a 100644
--- a/rag/generator/ollama.py
+++ b/rag/generator/ollama.py
@@ -36,8 +36,8 @@ class Ollama(metaclass=AbstractGenerator):
         )
         return metaprompt
 
-    def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
+    def generate(self, prompt: Prompt, memory: Memory) -> Generator[Any, Any, Any]:
         log.debug("Generating answer with ollama...")
         metaprompt = self.__metaprompt(prompt)
-        for chunk in ollama.generate(model=self.model, prompt=metaprompt, stream=True):
+        for chunk in ollama.chat(model=self.model, messages=memory.append(metaprompt), stream=True):
             yield chunk["response"]
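For context, the commit threads a chat history / Memory object into both generators but does not show the Memory implementation itself. Below is a minimal sketch of how such a class might look given how it is used above; the class name is taken from the diff, but its fields and the add/append/cohere_history helpers are assumptions for illustration, not code from this repository.

# Hypothetical sketch only: Memory is not included in this commit, so its API
# and message format are assumed here based on how the diff calls it.
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Memory:
    """Running chat history as a list of {"role", "content"} dicts."""

    messages: List[Dict[str, str]] = field(default_factory=list)

    def add(self, role: str, content: str) -> None:
        # Record a completed turn (e.g. the assistant's streamed answer).
        self.messages.append({"role": role, "content": content})

    def append(self, prompt: str) -> List[Dict[str, str]]:
        # Return the conversation plus the new user turn, in the
        # {"role": ..., "content": ...} shape that ollama.chat expects,
        # matching the messages=memory.append(metaprompt) call in the diff.
        return self.messages + [{"role": "user", "content": prompt}]

    def cohere_history(self) -> List[Dict[str, str]]:
        # Cohere's chat_stream takes chat_history entries keyed as
        # {"role": "USER" | "CHATBOT", "message": ...}, so the same memory
        # would need re-shaping before being passed as `history`.
        role_map = {"user": "USER", "assistant": "CHATBOT"}
        return [
            {"role": role_map.get(m["role"], "USER"), "message": m["content"]}
            for m in self.messages
        ]

One follow-up the WIP state hints at: ollama.chat streams chunks whose text lives under chunk["message"]["content"] (chunk["response"] is the shape returned by ollama.generate), so the final yield in ollama.py would likely still need adjusting in a later commit.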