summary refs log tree commit diff
path: root/rag/retriever/rerank/abstract.py
blob: 015a60d937ee3dfd70ca9f3fd51c4a46aa9c15d8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from abc import abstractmethod
from typing import List

from rag.message import Message
from rag.retriever.vector import Document


class AbstractReranker(type):
    """Metaclass that hands out one cached instance per reranker class.

    Calling a class created with this metaclass constructs it the first
    time and returns that same object on every later call; constructor
    arguments supplied on later calls are ignored.

    NOTE(review): ``abstractmethod`` only takes effect when the defining
    class's metaclass derives from ``abc.ABCMeta``. This class derives from
    ``type``, so the two stubs below are not enforced on anything. Because
    they are defined on the metaclass, they are also resolved on classes
    rather than on instances of those classes.
    """

    # Instances created so far, keyed by class. Defined once on the
    # metaclass, so the same dict serves every class that uses it.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        """Return the cached instance of ``cls``, building it on first use.

        Args:
            *args: Positional arguments forwarded to the normal construction
                path; only used when no instance is cached yet.
            **kwargs: Keyword arguments forwarded the same way.

        Returns:
            The single instance stored for ``cls``.
        """
        if cls in cls._instances:
            return cls._instances[cls]
        # First call for this class: construct normally, then remember it.
        created = super().__call__(*args, **kwargs)
        cls._instances[cls] = created
        return cls._instances[cls]

    @abstractmethod
    def rerank_documents(self, query: str, documents: List[Document]) -> List[Document]:
        """Stub with no behaviour of its own; it returns ``None``.

        Judging by the name and signature, implementations are expected to
        return ``documents`` reordered for ``query`` -- confirm against the
        concrete rerankers.
        """

    @abstractmethod
    def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]:
        """Stub with no behaviour of its own; it returns ``None``.

        Judging by the name and signature, implementations are expected to
        return ``messages`` reordered for ``query`` -- confirm against the
        concrete rerankers.
        """