summary | refs | log | tree | commit | diff
path: root/rag/generator
diff options
context:
space:
mode:
author: Gustaf Rydholm <gustaf.rydholm@gmail.com> 2024-06-19 02:07:06 +0200
committer: Gustaf Rydholm <gustaf.rydholm@gmail.com> 2024-06-19 02:07:06 +0200
commit: aac821b148c6c0d35b940609dc9b0ddcb053b28e (patch)
tree: 5c125045b2b60ead39e093327d664adf43d1d35b /rag/generator
parent: f2846429310452bebbf0d07203b1e53978c439c7 (diff)
Still wip on rewrite
Diffstat (limited to 'rag/generator')
-rw-r--r-- rag/generator/__init__.py 6
-rw-r--r-- rag/generator/abstract.py 11
-rw-r--r-- rag/generator/cohere.py 12
-rw-r--r-- rag/generator/ollama.py 11
-rw-r--r-- rag/generator/prompt.py 10
5 files changed, 21 insertions, 29 deletions
diff --git a/rag/generator/__init__.py b/rag/generator/__init__.py
index a776231..770db16 100644
--- a/rag/generator/__init__.py
+++ b/rag/generator/__init__.py
@@ -1,8 +1,8 @@
from typing import Type
-from .abstract import AbstractGenerator
-from .cohere import Cohere
-from .ollama import Ollama
+from rag.generator.abstract import AbstractGenerator
+from rag.generator.cohere import Cohere
+from rag.generator.ollama import Ollama
MODELS = ["local", "cohere"]
diff --git a/rag/generator/abstract.py b/rag/generator/abstract.py
index 995e937..3ce997e 100644
--- a/rag/generator/abstract.py
+++ b/rag/generator/abstract.py
@@ -1,9 +1,8 @@
from abc import abstractmethod
-from typing import Any, Generator
+from typing import Any, Generator, List
-from rag.rag import Message
-
-from .prompt import Prompt
+from rag.message import Message
+from rag.retriever.vector import Document
class AbstractGenerator(type):
@@ -16,7 +15,5 @@ class AbstractGenerator(type):
return cls._instances[cls]
@abstractmethod
- def generate(
- self, prompt: Prompt, messages: List[Message]
- ) -> Generator[Any, Any, Any]:
+ def generate(self, messages: List[Message], documents: List[Document]) -> Generator[Any, Any, Any]:
pass
diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py
index f30fe69..575452f 100644
--- a/rag/generator/cohere.py
+++ b/rag/generator/cohere.py
@@ -5,10 +5,10 @@ from typing import Any, Generator, List
import cohere
from loguru import logger as log
-from rag.rag import Message
+from rag.message import Message
+from rag.retriever.vector import Document
from .abstract import AbstractGenerator
-from .prompt import Prompt
class Cohere(metaclass=AbstractGenerator):
@@ -16,13 +16,13 @@ class Cohere(metaclass=AbstractGenerator):
self.client = cohere.Client(os.environ["COHERE_API_KEY"])
def generate(
- self, prompt: Prompt, messages: List[Message]
+ self, messages: List[Message], documents: List[Document]
) -> Generator[Any, Any, Any]:
log.debug("Generating answer from cohere...")
for event in self.client.chat_stream(
- message=prompt.to_str(),
- documents=[asdict(d) for d in prompt.documents],
- chat_history=[m.as_dict() for m in messages],
+ message=messages[-1].content,
+ documents=[asdict(d) for d in documents],
+ chat_history=[m.as_dict() for m in messages[:-1]],
prompt_truncation="AUTO",
):
if event.event_type == "text-generation":
diff --git a/rag/generator/ollama.py b/rag/generator/ollama.py
index ff5402b..84563bb 100644
--- a/rag/generator/ollama.py
+++ b/rag/generator/ollama.py
@@ -4,10 +4,10 @@ from typing import Any, Generator, List
import ollama
from loguru import logger as log
-from rag.rag import Message
+from rag.message import Message
+from rag.retriever.vector import Document
from .abstract import AbstractGenerator
-from .prompt import Prompt
class Ollama(metaclass=AbstractGenerator):
@@ -16,12 +16,9 @@ class Ollama(metaclass=AbstractGenerator):
log.debug(f"Using {self.model} for generator...")
def generate(
- self, prompt: Prompt, messages: List[Message]
+ self, messages: List[Message], documents: List[Document]
) -> Generator[Any, Any, Any]:
log.debug("Generating answer with ollama...")
- messages = messages.append(
- Message(role="user", content=prompt.to_str(), client="ollama")
- )
messages = [m.as_dict() for m in messages]
for chunk in ollama.chat(model=self.model, messages=messages, stream=True):
- yield chunk["response"]
+ yield chunk["message"]["content"]
diff --git a/rag/generator/prompt.py b/rag/generator/prompt.py
index 4840fdc..cedf610 100644
--- a/rag/generator/prompt.py
+++ b/rag/generator/prompt.py
@@ -15,7 +15,7 @@ ANSWER_INSTRUCTION = (
class Prompt:
query: str
documents: List[Document]
- generator_model: str
+ client: str
def __context(self, documents: List[Document]) -> str:
results = [
@@ -25,17 +25,15 @@ class Prompt:
return "\n".join(results)
def to_str(self) -> str:
- if self.generator_model == "cohere":
+ if self.client == "cohere":
return f"{self.query}\n\n{ANSWER_INSTRUCTION}"
else:
return (
"Context information is below.\n"
- "---------------------\n"
+ "---\n"
f"{self.__context(self.documents)}\n\n"
- "---------------------\n"
+ "---\n"
f"{ANSWER_INSTRUCTION}"
- "Do not attempt to answer the query without relevant context and do not use"
- " prior knowledge or training data!\n"
f"Query: {self.query.strip()}\n\n"
"Answer:"
)