diff options
author    Gustaf Rydholm <gustaf.rydholm@gmail.com>  2024-04-13 02:26:01 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com>  2024-04-13 02:26:01 +0200
commit    72d1caf92115d90ae789de1cffed29406f2a0a39 (patch)
tree      ca4f3b755dccdd94894f18e7ce599cfd7eb28e58 /rag/generator
parent    36722903391ec42d5458112bc0549eb843548d90 (diff)
Wip chat ui
Diffstat (limited to 'rag/generator')
-rw-r--r--  rag/generator/abstract.py |  8
-rw-r--r--  rag/generator/cohere.py   | 22
-rw-r--r--  rag/generator/ollama.py   |  9
3 files changed, 35 insertions(+), 4 deletions(-)
diff --git a/rag/generator/abstract.py b/rag/generator/abstract.py
index 1beacfb..71edfc4 100644
--- a/rag/generator/abstract.py
+++ b/rag/generator/abstract.py
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import Any, Generator
+from typing import Any, Dict, Generator, List
 
 from .prompt import Prompt
 
@@ -16,3 +16,9 @@ class AbstractGenerator(type):
     @abstractmethod
     def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
         pass
+
+    @abstractmethod
+    def chat(
+        self, prompt: Prompt, messages: List[Dict[str, str]]
+    ) -> Generator[Any, Any, Any]:
+        pass
diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py
index 2ed2cf5..16dfe88 100644
--- a/rag/generator/cohere.py
+++ b/rag/generator/cohere.py
@@ -1,6 +1,6 @@
 import os
 from dataclasses import asdict
-from typing import Any, Generator
+from typing import Any, Dict, Generator, List
 
 import cohere
 from loguru import logger as log
@@ -14,7 +14,7 @@ class Cohere(metaclass=AbstractGenerator):
         self.client = cohere.Client(os.environ["COHERE_API_KEY"])
 
     def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
-        log.debug("Generating answer from cohere")
+        log.debug("Generating answer from cohere...")
         query = f"{prompt.query}\n\n{ANSWER_INSTRUCTION}"
         for event in self.client.chat_stream(
             message=query,
@@ -27,3 +27,21 @@ class Cohere(metaclass=AbstractGenerator):
                 yield event.citations
             elif event.event_type == "stream-end":
                 yield event.finish_reason
+
+    def chat(
+        self, prompt: Prompt, messages: List[Dict[str, str]]
+    ) -> Generator[Any, Any, Any]:
+        log.debug("Generating answer from cohere...")
+        query = f"{prompt.query}\n\n{ANSWER_INSTRUCTION}"
+        for event in self.client.chat_stream(
+            message=query,
+            documents=[asdict(d) for d in prompt.documents],
+            chat_history=messages,
+            prompt_truncation="AUTO",
+        ):
+            if event.event_type == "text-generation":
+                yield event.text
+            # elif event.event_type == "citation-generation":
+            #     yield event.citations
+            elif event.event_type == "stream-end":
+                yield event.finish_reason
diff --git a/rag/generator/ollama.py b/rag/generator/ollama.py
index 6340235..b475dcf 100644
--- a/rag/generator/ollama.py
+++ b/rag/generator/ollama.py
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Generator, List
+from typing import Any, Dict, Generator, List
 
 import ollama
 from loguru import logger as log
@@ -38,3 +38,10 @@ class Ollama(metaclass=AbstractGenerator):
         metaprompt = self.__metaprompt(prompt)
         for chunk in ollama.generate(model=self.model, prompt=metaprompt, stream=True):
             yield chunk["response"]
+
+    def chat(self, prompt: Prompt, messages: List[Dict[str, str]]) -> Generator[Any, Any, Any]:
+        log.debug("Generating answer with ollama...")
+        metaprompt = self.__metaprompt(prompt)
+        messages.append({"role": "user", "content": metaprompt})
+        for chunk in ollama.chat(model=self.model, messages=messages, stream=True):
+            yield chunk["message"]["content"]