summaryrefslogtreecommitdiff
path: root/rag/generator/__init__.py
blob: ba23ffc4bf9e83c873553dc69658ff97dae96f57 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import sys
from typing import Type

from .abstract import AbstractGenerator
from .cohere import Cohere
from .ollama import Ollama

MODELS = ["ollama", "cohere"]

def get_generator(model: str) -> Type[AbstractGenerator]:
    match model:
        case "ollama":
            return Ollama()
        case "cohere":
            return Cohere()
        case _:
            exit(1)