Diffstat (limited to 'rag/generator')
-rw-r--r-- | rag/generator/__init__.py | 15
-rw-r--r-- | rag/generator/abstract.py | 19
-rw-r--r-- | rag/generator/cohere.py   | 29
-rw-r--r-- | rag/generator/ollama.py   | 71
-rw-r--r-- | rag/generator/prompt.py   | 14
5 files changed, 148 insertions, 0 deletions
diff --git a/rag/generator/__init__.py b/rag/generator/__init__.py
new file mode 100644
index 0000000..7da603c
--- /dev/null
+++ b/rag/generator/__init__.py
@@ -0,0 +1,15 @@
+from .abstract import AbstractGenerator
+from .cohere import Cohere
+from .ollama import Ollama
+
+
+def get_generator(model: str) -> Ollama | Cohere:
+    """Return the generator backend for the given model name."""
+    match model:
+        case "ollama":
+            return Ollama()
+        case "cohere":
+            return Cohere()
+        case _:
+            # Fail loudly on unsupported model names.
+            raise ValueError(f"Unknown generator model: {model}")
diff --git a/rag/generator/abstract.py b/rag/generator/abstract.py
new file mode 100644
index 0000000..a53b5d8
--- /dev/null
+++ b/rag/generator/abstract.py
@@ -0,0 +1,19 @@
+from abc import ABC, abstractmethod
+from typing import Any, Generator
+
+from .prompt import Prompt
+
+
+class AbstractGenerator(ABC, type):
+    """Metaclass that turns each generator class into a singleton."""
+    _instances = {}
+
+    def __call__(cls, *args, **kwargs):
+        # Create the instance on first use and cache it for later calls.
+        if cls not in cls._instances:
+            cls._instances[cls] = super().__call__(*args, **kwargs)
+        return cls._instances[cls]
+
+    @abstractmethod
+    def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
+        pass
diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py
new file mode 100644
index 0000000..cf95c18
--- /dev/null
+++ b/rag/generator/cohere.py
@@ -0,0 +1,29 @@
+import os
+from dataclasses import asdict
+from typing import Any, Generator
+
+import cohere
+from loguru import logger as log
+
+from .abstract import AbstractGenerator
+from .prompt import Prompt
+
+
+class Cohere(metaclass=AbstractGenerator):
+    def __init__(self) -> None:
+        self.client = cohere.Client(os.environ["COHERE_API_KEY"])
+
+    def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
+        log.debug("Generating answer from cohere")
+        # Stream the chat response; the documents let Cohere ground and cite.
+        for event in self.client.chat_stream(
+            message=prompt.query,
+            documents=[asdict(d) for d in prompt.documents],
+            prompt_truncation="AUTO",
+        ):
+            if event.event_type == "text-generation":
+                yield event.text
+            elif event.event_type == "citation-generation":
+                yield event.citations
+            elif event.event_type == "stream-end":
+                yield event.finish_reason
diff --git a/rag/generator/ollama.py b/rag/generator/ollama.py
new file mode 100644
index 0000000..ec6a55f
--- /dev/null
+++ b/rag/generator/ollama.py
@@ -0,0 +1,71 @@
+import os
+from typing import Any, Generator, List
+
+import ollama
+from loguru import logger as log
+
+from .abstract import AbstractGenerator
+from .prompt import Prompt
+
+try:
+    from rag.retriever.vector import Document
+except ModuleNotFoundError:
+    from retriever.vector import Document
+
+SYSTEM_PROMPT = (
+    "# System Preamble\n"
+    "## Basic Rules\n"
+    "When you answer the user's requests, you cite your sources in your answers, according to these instructions.\n"
+    "Answer the following question using the provided context.\n"
+    "## Style Guide\n"
+    "Unless the user asks for a different style of answer, you should answer "
+    "in full sentences, using proper grammar and spelling."
+)
+
+
+class Ollama(metaclass=AbstractGenerator):
+    def __init__(self) -> None:
+        self.model = os.environ["GENERATOR_MODEL"]
+
+    def __context(self, documents: List[Document]) -> str:
+        results = [
+            f"Document: {i}\ntitle: {doc.title}\n{doc.text}"
+            for i, doc in enumerate(documents)
+        ]
+        return "\n".join(results)
+
+    def __metaprompt(self, prompt: Prompt) -> str:
+        # Build the grounded-generation prompt, numbering each source document.
+        metaprompt = (
+            f'Question: "{prompt.query.strip()}"\n\n'
+            "Context:\n"
+            "<result>\n"
+            f"{self.__context(prompt.documents)}\n\n"
+            "</result>\n"
+            "Carefully perform the following instructions, in order, starting each "
+            "with a new line.\n"
+            "Firstly, decide which of the retrieved documents are relevant to the "
+            "user's last input by writing 'Relevant Documents:' followed by a "
+            "comma-separated list of document numbers. If none are relevant, you "
+            "should instead write 'None'.\n"
+            "Secondly, decide which of the retrieved documents contain facts that "
+            "should be cited in a good answer to the user's last input by writing "
+            "'Cited Documents:' followed by a comma-separated list of document numbers. "
+            "If you don't want to cite any of them, you should instead write 'None'.\n"
+            "Thirdly, write 'Answer:' followed by a response to the user's last input "
+            "in high quality natural English. Use the retrieved documents to help you. "
+            "Do not insert any citations or grounding markup.\n"
+            "Finally, write 'Grounded answer:' followed by a response to the user's "
+            "last input in high quality natural English. Use the symbols <co: doc> and "
+            "</co: doc> to indicate when a fact comes from a document in the search "
+            "result, e.g. <co: 0>my fact</co: 0> for a fact from document 0."
+        )
+        return metaprompt
+
+    def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
+        log.debug("Generating answer...")
+        metaprompt = self.__metaprompt(prompt)
+        for chunk in ollama.generate(
+            model=self.model, prompt=metaprompt, system=SYSTEM_PROMPT, stream=True
+        ):
+            yield chunk["response"]
diff --git a/rag/generator/prompt.py b/rag/generator/prompt.py
new file mode 100644
index 0000000..ed372c9
--- /dev/null
+++ b/rag/generator/prompt.py
@@ -0,0 +1,14 @@
+from dataclasses import dataclass
+from typing import List
+
+
+try:
+    from rag.retriever.vector import Document
+except ModuleNotFoundError:
+    from retriever.vector import Document
+
+
+@dataclass
+class Prompt:
+    query: str
+    documents: List[Document]
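
Usage notes (not part of the commit). A minimal sketch of how the pieces fit together, assuming the package is importable as `rag`, that `GENERATOR_MODEL` names a model already pulled into Ollama, and that `Document` is a dataclass with at least `title` and `text` fields, which is what the diff's usage implies; the sample values are made up:

import os

from rag.generator import get_generator
from rag.generator.prompt import Prompt
from rag.retriever.vector import Document

os.environ.setdefault("GENERATOR_MODEL", "llama3")  # hypothetical model name

generator = get_generator("ollama")
prompt = Prompt(
    query="What is a metaclass?",
    documents=[Document(title="Sample doc", text="A metaclass is ...")],  # made-up sample
)
for chunk in generator.generate(prompt):
    print(chunk, end="", flush=True)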
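
Because `AbstractGenerator.__call__` caches one instance per class, repeated factory calls hand back the same object, so each backend client is constructed only once (the first call still needs `COHERE_API_KEY` set):

from rag.generator import get_generator

a = get_generator("cohere")
b = get_generator("cohere")
assert a is b  # the metaclass returns the cached Cohere instance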
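
`Cohere.generate` yields values of mixed types: answer text for `text-generation` events, a list of citation objects for `citation-generation`, and a finish-reason string at `stream-end`. A consumer therefore has to branch on the chunk type; a sketch under those assumptions:

from rag.generator import get_generator
from rag.generator.prompt import Prompt

generator = get_generator("cohere")
for chunk in generator.generate(Prompt(query="What is RAG?", documents=[])):
    if isinstance(chunk, str):
        print(chunk, end="")  # streamed answer text, or the finish reason
    else:
        print(f"\ncitations: {chunk}")  # list of citation spans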
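
The Ollama metaprompt asks the model for four labelled sections ('Relevant Documents:', 'Cited Documents:', 'Answer:', 'Grounded answer:'), so a caller that wants only the plain answer has to split the accumulated text afterwards. The helper below is purely illustrative and not part of the diff:

import re


def extract_answer(full_text: str) -> str:
    # Hypothetical helper: pull the 'Answer:' section out of the
    # sectioned response that the metaprompt requests.
    match = re.search(r"Answer:(.*?)(?:Grounded answer:|$)", full_text, re.DOTALL)
    return match.group(1).strip() if match else full_text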