blob: 770db16e4efa89f76753d7b30ea38846f1da7e84 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
  | 
from typing import Type
from rag.generator.abstract import AbstractGenerator
from rag.generator.cohere import Cohere
from rag.generator.ollama import Ollama
# Backend keys understood by get_generator(), in the order they are offered.
MODELS = [
    "local",
    "cohere",
]
def get_generator(model: str) -> Type[AbstractGenerator]:
    match model:
        case "local":
            return Ollama()
        case "cohere":
            return Cohere()
        case _:
            exit(1)
  |