blob: 015a60d937ee3dfd70ca9f3fd51c4a46aa9c15d8 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
|
from abc import ABCMeta, abstractmethod
from typing import List

from rag.message import Message
from rag.retriever.vector import Document
class AbstractReranker(ABCMeta):
    """Metaclass that makes every class using it a singleton.

    A class declared with ``metaclass=AbstractReranker`` is constructed at most
    once. The first call builds the instance and caches it. Every later call
    returns that cached instance. Arguments passed on later calls are ignored:
    they never reach ``__init__`` again.

    The metaclass derives from ``ABCMeta`` rather than plain ``type``. This has
    two effects:

    * ``@abstractmethod`` declarations made inside a reranker class (or its
      bases) are enforced when the class is instantiated.
    * Reranker classes may also inherit from ``abc.ABC`` without a metaclass
      conflict.

    NOTE: ``rerank_documents`` and ``rerank_messages`` below are defined on the
    metaclass. They document the expected reranker interface, but reranker
    instances do not inherit them and instantiation does not enforce them.
    Each reranker class must define both methods itself.

    NOTE(review): the cache lookup/insert in ``__call__`` is not guarded by a
    lock. Confirm whether rerankers can be first constructed concurrently
    from several threads.
    """

    # One registry shared by every class built with this metaclass, keyed by
    # the class object, so each concrete reranker gets its own single instance.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        """Return the single instance of ``cls``, constructing it on first use."""
        if cls not in cls._instances:
            # The instance is stored only after construction succeeds, so a
            # failing __init__ leaves nothing cached and the next call retries.
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]

    @abstractmethod
    def rerank_documents(self, query: str, documents: List[Document]) -> List[Document]:
        """Return ``documents`` reranked with respect to ``query``.

        Interface declaration only. See the class docstring for why this is
        not inherited by reranker instances.
        """
        pass

    @abstractmethod
    def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]:
        """Return ``messages`` reranked with respect to ``query``.

        Interface declaration only. See the class docstring for why this is
        not inherited by reranker instances.
        """
        pass
|