# cgit: summary | refs | log | tree | commit | diff
# path: root/rag/generator/abstract.py
# blob: 71edfc457893246261cd3ccd2f6fa11deeb4476b (plain)
from abc import abstractmethod
from typing import Any, Dict, Generator, List

from .prompt import Prompt


class AbstractGenerator(type):
    """Metaclass that makes every class created with it a per-class singleton.

    The first call ``SomeGenerator(*args, **kwargs)`` constructs an instance
    and caches it. Every later call returns that same cached object; the
    arguments passed on those later calls are ignored.

    ``generate`` and ``chat`` declare the interface generator classes are
    expected to provide. NOTE(review): this class derives from ``type``, not
    ``abc.ABCMeta``, and the methods are defined on the metaclass rather than
    on the generator classes. As a result ``@abstractmethod`` here is NOT
    enforced: a class created with ``metaclass=AbstractGenerator`` can be
    instantiated without defining them. Treat them as interface documentation.
    """

    # Cache shared by every class created with this metaclass. It is keyed by
    # the class object, so each class (and each subclass) has its own singleton.
    _instances: Dict[type, Any] = {}

    def __call__(cls, *args, **kwargs):
        """Return the singleton instance of ``cls``, creating it on first use.

        Args:
            *args: Forwarded to the class constructor on the first call only.
            **kwargs: Forwarded to the class constructor on the first call only.

        Returns:
            The cached instance of ``cls``.
        """
        # Read the cache from the metaclass explicitly. ``cls._instances``
        # would search ``cls`` and its bases first, so a generator class that
        # defined its own ``_instances`` attribute would shadow this cache.
        instances = type(cls)._instances
        if cls not in instances:
            # Store only after construction succeeds. If ``__init__`` raises,
            # nothing is cached and the next call tries again.
            instances[cls] = super().__call__(*args, **kwargs)
        return instances[cls]

    @abstractmethod
    def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
        """Generate output for ``prompt``, returned as a generator.

        Interface declaration only; this body does nothing. Presumably the
        generator yields streamed output chunks — confirm against concrete
        implementations.
        """

    @abstractmethod
    def chat(
        self, prompt: Prompt, messages: List[Dict[str, str]]
    ) -> Generator[Any, Any, Any]:
        """Continue a conversation for ``prompt`` given prior ``messages``.

        Interface declaration only; this body does nothing. ``messages`` is a
        list of string-to-string dicts. Presumably these are chat-history
        entries such as role/content pairs — confirm against callers.
        """