Diffstat (limited to 'rag/generator')

 rag/generator/abstract.py |  8 +++++++-
 rag/generator/cohere.py   | 22 ++++++++++++++++++++--
 rag/generator/ollama.py   |  9 ++++++++-
 3 files changed, 35 insertions(+), 4 deletions(-)
diff --git a/rag/generator/abstract.py b/rag/generator/abstract.py
index 1beacfb..71edfc4 100644
--- a/rag/generator/abstract.py
+++ b/rag/generator/abstract.py
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import Any, Generator
+from typing import Any, Dict, Generator, List
 
 from .prompt import Prompt
 
@@ -16,3 +16,9 @@ class AbstractGenerator(type):
     @abstractmethod
     def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
         pass
+
+    @abstractmethod
+    def chat(
+        self, prompt: Prompt, messages: List[Dict[str, str]]
+    ) -> Generator[Any, Any, Any]:
+        pass
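
The new abstract method extends the generator contract: alongside the one-shot generate(), every backend must now accept the running message history and stream its reply. A minimal conforming backend could look like the sketch below; the Echo class and the absolute import paths are hypothetical, written only to illustrate the interface:

    # Hypothetical backend showing the shape of the new contract.
    from typing import Any, Dict, Generator, List

    from rag.generator.abstract import AbstractGenerator
    from rag.generator.prompt import Prompt

    class Echo(metaclass=AbstractGenerator):
        def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
            # Stream the query back one word at a time.
            yield from prompt.query.split()

        def chat(
            self, prompt: Prompt, messages: List[Dict[str, str]]
        ) -> Generator[Any, Any, Any]:
            # Reply to the most recent turn, falling back to the bare query.
            last = messages[-1]["content"] if messages else prompt.query
            yield from last.split()
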
diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py
index 2ed2cf5..16dfe88 100644
--- a/rag/generator/cohere.py
+++ b/rag/generator/cohere.py
@@ -1,6 +1,6 @@
 import os
 from dataclasses import asdict
-from typing import Any, Generator
+from typing import Any, Dict, Generator, List
 
 import cohere
 from loguru import logger as log
@@ -14,7 +14,7 @@ class Cohere(metaclass=AbstractGenerator):
         self.client = cohere.Client(os.environ["COHERE_API_KEY"])
 
     def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
-        log.debug("Generating answer from cohere")
+        log.debug("Generating answer from cohere...")
         query = f"{prompt.query}\n\n{ANSWER_INSTRUCTION}"
         for event in self.client.chat_stream(
             message=query,
@@ -27,3 +27,21 @@ class Cohere(metaclass=AbstractGenerator):
                 yield event.citations
             elif event.event_type == "stream-end":
                 yield event.finish_reason
+
+    def chat(
+        self, prompt: Prompt, messages: List[Dict[str, str]]
+    ) -> Generator[Any, Any, Any]:
+        log.debug("Generating answer from cohere...")
+        query = f"{prompt.query}\n\n{ANSWER_INSTRUCTION}"
+        for event in self.client.chat_stream(
+            message=query,
+            documents=[asdict(d) for d in prompt.documents],
+            chat_history=messages,
+            prompt_truncation="AUTO",
+        ):
+            if event.event_type == "text-generation":
+                yield event.text
+            # elif event.event_type == "citation-generation":
+            #     yield event.citations
+            elif event.event_type == "stream-end":
+                yield event.finish_reason
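
With chat_history threaded through to chat_stream, the Cohere backend keeps multi-turn context while still grounding each answer on the retrieved documents. A rough usage sketch follows; the Prompt(query=..., documents=...) constructor, the `retrieved` placeholder, and the USER/CHATBOT history shape expected by Cohere's chat_history are assumptions, not part of this commit:

    # Usage sketch; assumes COHERE_API_KEY is set in the environment.
    from rag.generator.cohere import Cohere
    from rag.generator.prompt import Prompt

    generator = Cohere()
    history = [
        {"role": "USER", "message": "What does the report recommend?"},
        {"role": "CHATBOT", "message": "An incremental rollout."},
    ]
    # `retrieved` stands in for whatever documents the retriever returned.
    prompt = Prompt(query="Why incremental?", documents=retrieved)
    for token in generator.chat(prompt, history):
        print(token, end="", flush=True)

Note that the two backends expect differently shaped history entries: this path forwards them to Cohere's chat_history, while the Ollama path below appends role/content dictionaries, so callers must build a per-backend history.
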
diff --git a/rag/generator/ollama.py b/rag/generator/ollama.py
index 6340235..b475dcf 100644
--- a/rag/generator/ollama.py
+++ b/rag/generator/ollama.py
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Generator, List
+from typing import Any, Dict, Generator, List
 
 import ollama
 from loguru import logger as log
@@ -38,3 +38,10 @@ class Ollama(metaclass=AbstractGenerator):
         metaprompt = self.__metaprompt(prompt)
         for chunk in ollama.generate(model=self.model, prompt=metaprompt, stream=True):
             yield chunk["response"]
+
+    def chat(self, prompt: Prompt, messages: List[Dict[str, str]]) -> Generator[Any, Any, Any]:
+        log.debug("Generating answer with ollama...")
+        metaprompt = self.__metaprompt(prompt)
+        messages.append({"role": "user", "content": metaprompt})
+        for chunk in ollama.chat(model=self.model, messages=messages, stream=True):
+            yield chunk["message"]["content"]
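
The Ollama variant folds the retrieved context into a metaprompt, appends it to messages as the newest user turn, and streams the reply through ollama.chat. Because messages is mutated in place, the caller's history already contains that turn once the generator is consumed. A matching usage sketch, with the same hedges as above (the Prompt constructor and the `retrieved` placeholder are assumptions):

    # Usage sketch; assumes a local Ollama server and the model picked by Ollama().
    from rag.generator.ollama import Ollama
    from rag.generator.prompt import Prompt

    generator = Ollama()
    history = [
        {"role": "user", "content": "Summarise the uploaded notes."},
        {"role": "assistant", "content": "They cover the Q3 migration plan."},
    ]
    prompt = Prompt(query="What are the open risks?", documents=retrieved)
    for token in generator.chat(prompt, history):
        print(token, end="", flush=True)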