Diffstat (limited to 'rag/generator')
-rw-r--r--  rag/generator/__init__.py  15
-rw-r--r--  rag/generator/abstract.py  19
-rw-r--r--  rag/generator/cohere.py    29
-rw-r--r--  rag/generator/ollama.py    71
-rw-r--r--  rag/generator/prompt.py    14
5 files changed, 148 insertions, 0 deletions
diff --git a/rag/generator/__init__.py b/rag/generator/__init__.py
new file mode 100644
index 0000000..7da603c
--- /dev/null
+++ b/rag/generator/__init__.py
@@ -0,0 +1,15 @@
+from typing import Union
+
+from .abstract import AbstractGenerator
+from .ollama import Ollama
+from .cohere import Cohere
+
+
+def get_generator(model: str) -> Union[Ollama, Cohere]:
+    match model:
+        case "ollama":
+            return Ollama()
+        case "cohere":
+            return Cohere()
+        case _:
+            exit(f"Unknown generator: {model}")
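
The factory returns singleton instances: each generator class uses AbstractGenerator (defined below in abstract.py) as its metaclass, so repeated calls with the same model name yield the same object. A minimal usage sketch, assuming the rag package is importable and the environment variables the generators read (GENERATOR_MODEL, COHERE_API_KEY) are set:

    from rag.generator import get_generator

    generator = get_generator("ollama")
    again = get_generator("ollama")
    assert generator is again  # one cached instance per generator class
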
diff --git a/rag/generator/abstract.py b/rag/generator/abstract.py
new file mode 100644
index 0000000..a53b5d8
--- /dev/null
+++ b/rag/generator/abstract.py
@@ -0,0 +1,19 @@
+from abc import ABC, abstractmethod
+
+from typing import Any, Generator
+
+from .prompt import Prompt
+
+
+class AbstractGenerator(ABC, type):
+    _instances = {}
+
+    def __call__(cls, *args, **kwargs):
+        if cls not in cls._instances:
+            instance = super().__call__(*args, **kwargs)
+            cls._instances[cls] = instance
+        return cls._instances[cls]
+
+    @abstractmethod
+    def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
+        pass
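
AbstractGenerator is both an ABC and a metaclass: subclassing type lets it intercept class instantiation in __call__ and cache one instance per class, while @abstractmethod documents the generate interface each generator must provide. A minimal sketch of that behaviour, using a hypothetical Echo generator that is not part of this commit:

    from rag.generator.abstract import AbstractGenerator
    from rag.generator.prompt import Prompt

    class Echo(metaclass=AbstractGenerator):
        def generate(self, prompt: Prompt):
            yield prompt.query  # trivially echo the query back

    first = Echo()
    second = Echo()
    assert first is second  # __call__ returned the cached instance
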
diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py
new file mode 100644
index 0000000..cf95c18
--- /dev/null
+++ b/rag/generator/cohere.py
@@ -0,0 +1,29 @@
+import os
+from typing import Any, Generator
+import cohere
+
+from dataclasses import asdict
+
+from .prompt import Prompt
+from .abstract import AbstractGenerator
+
+from loguru import logger as log
+
+
+class Cohere(metaclass=AbstractGenerator):
+    def __init__(self) -> None:
+        self.client = cohere.Client(os.environ["COHERE_API_KEY"])
+
+    def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
+        log.debug("Generating answer from cohere")
+        for event in self.client.chat_stream(
+            message=prompt.query,
+            documents=[asdict(d) for d in prompt.documents],
+            prompt_truncation="AUTO",
+        ):
+            if event.event_type == "text-generation":
+                yield event.text
+            elif event.event_type == "citation-generation":
+                yield event.citations
+            elif event.event_type == "stream-end":
+                yield event.finish_reason
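
generate yields heterogeneous values from Cohere's streaming chat API: text chunks as the answer is produced, citation batches, and finally the finish reason. A consumption sketch, assuming COHERE_API_KEY is exported; the empty document list just keeps the example self-contained:

    from rag.generator import get_generator
    from rag.generator.prompt import Prompt

    cohere_generator = get_generator("cohere")
    prompt = Prompt(query="What is retrieval-augmented generation?", documents=[])
    for chunk in cohere_generator.generate(prompt):
        print(chunk, end="", flush=True)  # text, citations, then finish reason
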
diff --git a/rag/generator/ollama.py b/rag/generator/ollama.py
new file mode 100644
index 0000000..ec6a55f
--- /dev/null
+++ b/rag/generator/ollama.py
@@ -0,0 +1,71 @@
+import os
+from typing import Any, Generator, List
+
+import ollama
+from loguru import logger as log
+
+from .prompt import Prompt
+from .abstract import AbstractGenerator
+
+try:
+    from rag.retriever.vector import Document
+except ModuleNotFoundError:
+    from retriever.vector import Document
+
+SYSTEM_PROMPT = (
+    "# System Preamble\n"
+    "## Basic Rules\n"
+    "When you answer the user's requests, you cite your sources in your answers, according to those instructions.\n"
+    "Answer the following question using the provided context.\n"
+    "## Style Guide\n"
+    "Unless the user asks for a different style of answer, you should answer "
+    "in full sentences, using proper grammar and spelling."
+)
+
+
+class Ollama(metaclass=AbstractGenerator):
+    def __init__(self) -> None:
+        self.model = os.environ["GENERATOR_MODEL"]
+
+    def __context(self, documents: List[Document]) -> str:
+        results = [
+            f"Document: {i}\ntitle: {doc.title}\n{doc.text}"
+            for i, doc in enumerate(documents)
+        ]
+        return "\n".join(results)
+
+    def __metaprompt(self, prompt: Prompt) -> str:
+        # Include sources
+        metaprompt = (
+            f'Question: "{prompt.query.strip()}"\n\n'
+            "Context:\n"
+            "<result>\n"
+            f"{self.__context(prompt.documents)}\n\n"
+            "</result>\n"
+            "Carefully perform the following instructions, in order, starting each "
+            "with a new line.\n"
+            "Firstly, Decide which of the retrieved documents are relevant to the "
+            "user's last input by writing 'Relevant Documents:' followed by a "
+            "comma-separated list of document numbers.\nIf none are relevant, you "
+            "should instead write 'None'.\n"
+            "Secondly, Decide which of the retrieved documents contain facts that "
+            "should be cited in a good answer to the user's last input by writing "
+            "'Cited Documents:' followed by a comma-separated list of document numbers. "
+            "If you don't want to cite any of them, you should instead write 'None'.\n"
+            "Thirdly, Write 'Answer:' followed by a response to the user's last input "
+            "in high quality natural English. Use the retrieved documents to help you. "
+            "Do not insert any citations or grounding markup.\n"
+            "Finally, Write 'Grounded answer:' followed by a response to the user's "
+            "last input in high quality natural English. Use the symbols <co: doc> and "
+            "</co: doc> to indicate when a fact comes from a document in the search "
+            "result, e.g. <co: 0>my fact</co: 0> for a fact from document 0."
+        )
+        return metaprompt
+
+    def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
+        log.debug("Generating answer...")
+        metaprompt = self.__metaprompt(prompt)
+        for chunk in ollama.generate(
+            model=self.model, prompt=metaprompt, system=SYSTEM_PROMPT, stream=True
+        ):
+            yield chunk["response"]
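
The metaprompt follows Cohere's grounded-generation format (Relevant Documents / Cited Documents / Answer / Grounded answer) so that a local model served by Ollama emits citations in the same shape the Cohere generator does. A usage sketch, assuming a local Ollama server is running and GENERATOR_MODEL names a pulled model ("llama3" here is a placeholder):

    import os

    os.environ["GENERATOR_MODEL"] = "llama3"  # placeholder model name

    from rag.generator.ollama import Ollama
    from rag.generator.prompt import Prompt

    generator = Ollama()
    prompt = Prompt(query="What is retrieval-augmented generation?", documents=[])
    for token in generator.generate(prompt):
        print(token, end="", flush=True)
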
diff --git a/rag/generator/prompt.py b/rag/generator/prompt.py
new file mode 100644
index 0000000..ed372c9
--- /dev/null
+++ b/rag/generator/prompt.py
@@ -0,0 +1,14 @@
+from dataclasses import dataclass
+from typing import List
+
+
+try:
+    from rag.retriever.vector import Document
+except ModuleNotFoundError:
+    from retriever.vector import Document
+
+
+@dataclass
+class Prompt:
+    query: str
+    documents: List[Document]
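
Prompt is the shared contract between the retriever and the generators: a raw query plus the retrieved documents. A construction sketch, assuming Document is a dataclass exposing at least the title and text fields the generators read (the field list is illustrative, not the actual definition in retriever/vector.py):

    from rag.generator.prompt import Prompt
    from rag.retriever.vector import Document

    docs = [Document(title="RAG", text="Retrieval-augmented generation combines search with an LLM.")]
    prompt = Prompt(query="What is RAG?", documents=docs)
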