summaryrefslogtreecommitdiff
path: root/rag/retriever/rerank
diff options
context:
space:
mode:
Diffstat (limited to 'rag/retriever/rerank')
-rw-r--r--rag/retriever/rerank/abstract.py2
-rw-r--r--rag/retriever/rerank/cohere.py10
-rw-r--r--rag/retriever/rerank/local.py2
3 files changed, 7 insertions, 7 deletions
diff --git a/rag/retriever/rerank/abstract.py b/rag/retriever/rerank/abstract.py
index f32ee77..015a60d 100644
--- a/rag/retriever/rerank/abstract.py
+++ b/rag/retriever/rerank/abstract.py
@@ -1,7 +1,7 @@
from abc import abstractmethod
from typing import List
-from rag.memory import Message
+from rag.message import Message
from rag.retriever.vector import Document
diff --git a/rag/retriever/rerank/cohere.py b/rag/retriever/rerank/cohere.py
index 33c373d..52f31a8 100644
--- a/rag/retriever/rerank/cohere.py
+++ b/rag/retriever/rerank/cohere.py
@@ -4,7 +4,7 @@ from typing import List
import cohere
from loguru import logger as log
-from rag.rag import Message
+from rag.message import Message
from rag.retriever.rerank.abstract import AbstractReranker
from rag.retriever.vector import Document
@@ -35,11 +35,11 @@ class CohereReranker(metaclass=AbstractReranker):
def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]:
- response = self.model.rank(
+ response = self.client.rerank(
+ model="rerank-english-v3.0",
query=query,
- documents=[m.message for m in messages],
- return_documents=False,
- top_k=self.top_k,
+ documents=[m.content for m in messages],
+ top_n=self.top_k,
)
ranking = list(
filter(
diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py
index e727165..fd42c2c 100644
--- a/rag/retriever/rerank/local.py
+++ b/rag/retriever/rerank/local.py
@@ -2,7 +2,7 @@ import os
from typing import List
from loguru import logger as log
-from rag.rag import Message
+from rag.message import Message
from rag.retriever.vector import Document
from sentence_transformers import CrossEncoder