summaryrefslogtreecommitdiff
path: root/rag/retriever/rerank/__init__.py
blob: 16b2fac2f4bf1352c910a2812e5a97b0fdfdd7d0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from typing import Type

from rag.retriever.rerank.abstract import AbstractReranker
from rag.retriever.rerank.cohere import CohereReranker
from rag.retriever.rerank.local import Reranker


def get_reranker(model: str) -> Type[AbstractReranker]:
    match model:
        case "local":
            return Reranker()
        case "cohere":
            return CohereReranker()
        case _:
            exit(1)