author    Gustaf Rydholm <gustaf.rydholm@gmail.com>  2024-06-18 01:37:32 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com>  2024-06-18 01:37:32 +0200
commit    b1ff0c55422d7b0af2c379679b8721014ef36926 (patch)
tree      52aa88b2a8a0bba07f968c6ae24c002ce2d44226 /rag/generator
parent    b8c6a78f70d84f3360461aa91864e8538569d450 (diff)
Wip rewrite
Diffstat (limited to 'rag/generator')
-rw-r--r--  rag/generator/abstract.py |  6
-rw-r--r--  rag/generator/cohere.py   | 13
-rw-r--r--  rag/generator/ollama.py   | 36
-rw-r--r--  rag/generator/prompt.py   | 24
4 files changed, 46 insertions(+), 33 deletions(-)
diff --git a/rag/generator/abstract.py b/rag/generator/abstract.py
index 1beacfb..995e937 100644
--- a/rag/generator/abstract.py
+++ b/rag/generator/abstract.py
@@ -1,6 +1,8 @@
from abc import abstractmethod
-from typing import Any, Generator
+from typing import Any, Generator, List
+from rag.rag import Message
+
from .prompt import Prompt
@@ -14,5 +16,7 @@ class AbstractGenerator(type):
return cls._instances[cls]
@abstractmethod
- def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
+ def generate(
+ self, prompt: Prompt, messages: List[Message]
+ ) -> Generator[Any, Any, Any]:
pass
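
The new generate signature threads the chat history through as List[Message]. The Message type comes from rag.rag and is not part of this patch; the following is a minimal hypothetical sketch of what it plausibly looks like, inferred only from the fields and as_dict() calls used elsewhere in this diff, not the actual definition:

    # Hypothetical reconstruction of rag/rag.py's Message (assumption, not shown in this patch).
    from dataclasses import dataclass
    from typing import Dict

    @dataclass
    class Message:
        role: str     # "user" / "assistant" (Cohere's history uses "USER" / "CHATBOT")
        content: str
        client: str   # backend tag seen in this diff, e.g. "ollama" or "cohere"

        def as_dict(self) -> Dict[str, str]:
            # Cohere's chat_history expects {"role", "message"}; ollama expects
            # {"role", "content"}. Branch on the client tag.
            if self.client == "cohere":
                return {"role": self.role, "message": self.content}
            return {"role": self.role, "content": self.content}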
diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py
index fb0cc5b..f30fe69 100644
--- a/rag/generator/cohere.py
+++ b/rag/generator/cohere.py
@@ -1,12 +1,14 @@
import os
from dataclasses import asdict
-from typing import Any, Dict, Generator, List, Optional
+from typing import Any, Generator, List
import cohere
from loguru import logger as log
+from rag.rag import Message
+
from .abstract import AbstractGenerator
-from .prompt import ANSWER_INSTRUCTION, Prompt
+from .prompt import Prompt
class Cohere(metaclass=AbstractGenerator):
@@ -14,14 +16,13 @@ class Cohere(metaclass=AbstractGenerator):
self.client = cohere.Client(os.environ["COHERE_API_KEY"])
def generate(
- self, prompt: Prompt, history: Optional[List[Dict[str, str]]]
+ self, prompt: Prompt, messages: List[Message]
) -> Generator[Any, Any, Any]:
log.debug("Generating answer from cohere...")
- query = f"{prompt.query}\n\n{ANSWER_INSTRUCTION}"
for event in self.client.chat_stream(
- message=query,
+ message=prompt.to_str(),
documents=[asdict(d) for d in prompt.documents],
- chat_history=history,
+ chat_history=[m.as_dict() for m in messages],
prompt_truncation="AUTO",
):
if event.event_type == "text-generation":
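
For reference, a hedged call-site sketch of the new Cohere.generate signature (names and values are illustrative; assumes COHERE_API_KEY is exported and that Prompt carries the generator_model field added below in prompt.py):

    from rag.generator.cohere import Cohere
    from rag.generator.prompt import Prompt
    from rag.rag import Message

    # Empty documents list for brevity; real calls pass retrieved Documents.
    prompt = Prompt(query="What is RAG?", documents=[], generator_model="cohere")
    history = [Message(role="USER", content="hello", client="cohere")]
    for token in Cohere().generate(prompt, history):
        print(token, end="", flush=True)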
diff --git a/rag/generator/ollama.py b/rag/generator/ollama.py
index 9bf551a..ff5402b 100644
--- a/rag/generator/ollama.py
+++ b/rag/generator/ollama.py
@@ -4,10 +4,10 @@ from typing import Any, Generator, List
import ollama
from loguru import logger as log
-from rag.retriever.vector import Document
+from rag.rag import Message
from .abstract import AbstractGenerator
-from .prompt import ANSWER_INSTRUCTION, Prompt
+from .prompt import Prompt
class Ollama(metaclass=AbstractGenerator):
@@ -15,29 +15,13 @@ class Ollama(metaclass=AbstractGenerator):
self.model = os.environ["GENERATOR_MODEL"]
log.debug(f"Using {self.model} for generator...")
- def __context(self, documents: List[Document]) -> str:
- results = [
- f"Document: {i}\ntitle: {doc.title}\ntext: {doc.text}"
- for i, doc in enumerate(documents)
- ]
- return "\n".join(results)
-
- def __metaprompt(self, prompt: Prompt) -> str:
- metaprompt = (
- "Context information is below.\n"
- "---------------------\n"
- f"{self.__context(prompt.documents)}\n\n"
- "---------------------\n"
- f"{ANSWER_INSTRUCTION}"
- "Do not attempt to answer the query without relevant context and do not use"
- " prior knowledge or training data!\n"
- f"Query: {prompt.query.strip()}\n\n"
- "Answer:"
- )
- return metaprompt
-
- def generate(self, prompt: Prompt, memory: Memory) -> Generator[Any, Any, Any]:
+ def generate(
+ self, prompt: Prompt, messages: List[Message]
+ ) -> Generator[Any, Any, Any]:
log.debug("Generating answer with ollama...")
-        metaprompt = self.__metaprompt(prompt)
-        for chunk in ollama.chat(model=self.model, messages=memory.append(metaprompt), stream=True):
-            yield chunk["response"]
+        messages.append(
+            Message(role="user", content=prompt.to_str(), client="ollama")
+        )
+        history = [m.as_dict() for m in messages]
+        for chunk in ollama.chat(model=self.model, messages=history, stream=True):
+            yield chunk["message"]["content"]
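
Two details worth noting in the new generate body: list.append mutates in place and returns None, so the appended history must be converted to dicts in a separate step; and ollama.chat with stream=True yields message chunks whose text lives under chunk["message"]["content"] (chunk["response"] is the key used by ollama.generate). A minimal standalone sketch of the streaming call, assuming any locally pulled model:

    import ollama

    stream = ollama.chat(
        model="llama3",  # assumption: substitute any model pulled via `ollama pull`
        messages=[{"role": "user", "content": "Say hello."}],
        stream=True,
    )
    for chunk in stream:
        print(chunk["message"]["content"], end="", flush=True)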
diff --git a/rag/generator/prompt.py b/rag/generator/prompt.py
index 6523842..4840fdc 100644
--- a/rag/generator/prompt.py
+++ b/rag/generator/prompt.py
@@ -15,3 +15,27 @@ ANSWER_INSTRUCTION = (
class Prompt:
query: str
documents: List[Document]
+ generator_model: str
+
+ def __context(self, documents: List[Document]) -> str:
+ results = [
+ f"Document: {i}\ntitle: {doc.title}\ntext: {doc.text}"
+ for i, doc in enumerate(documents)
+ ]
+ return "\n".join(results)
+
+ def to_str(self) -> str:
+ if self.generator_model == "cohere":
+ return f"{self.query}\n\n{ANSWER_INSTRUCTION}"
+ else:
+ return (
+ "Context information is below.\n"
+ "---------------------\n"
+ f"{self.__context(self.documents)}\n\n"
+ "---------------------\n"
+ f"{ANSWER_INSTRUCTION}"
+ "Do not attempt to answer the query without relevant context and do not use"
+ " prior knowledge or training data!\n"
+ f"Query: {self.query.strip()}\n\n"
+ "Answer:"
+ )