diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-06-19 02:07:06 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-06-19 02:07:06 +0200 |
commit | aac821b148c6c0d35b940609dc9b0ddcb053b28e (patch) | |
tree | 5c125045b2b60ead39e093327d664adf43d1d35b /rag/retriever | |
parent | f2846429310452bebbf0d07203b1e53978c439c7 (diff) |
Still wip on rewrite
Diffstat (limited to 'rag/retriever')
-rw-r--r-- | rag/retriever/rerank/abstract.py | 2 | ||||
-rw-r--r-- | rag/retriever/rerank/cohere.py | 10 | ||||
-rw-r--r-- | rag/retriever/rerank/local.py | 2 |
3 files changed, 7 insertions, 7 deletions
diff --git a/rag/retriever/rerank/abstract.py b/rag/retriever/rerank/abstract.py index f32ee77..015a60d 100644 --- a/rag/retriever/rerank/abstract.py +++ b/rag/retriever/rerank/abstract.py @@ -1,7 +1,7 @@ from abc import abstractmethod from typing import List -from rag.memory import Message +from rag.message import Message from rag.retriever.vector import Document diff --git a/rag/retriever/rerank/cohere.py b/rag/retriever/rerank/cohere.py index 33c373d..52f31a8 100644 --- a/rag/retriever/rerank/cohere.py +++ b/rag/retriever/rerank/cohere.py @@ -4,7 +4,7 @@ from typing import List import cohere from loguru import logger as log -from rag.rag import Message +from rag.message import Message from rag.retriever.rerank.abstract import AbstractReranker from rag.retriever.vector import Document @@ -35,11 +35,11 @@ class CohereReranker(metaclass=AbstractReranker): def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]: - response = self.model.rank( + response = self.client.rerank( + model="rerank-english-v3.0", query=query, - documents=[m.message for m in messages], - return_documents=False, - top_k=self.top_k, + documents=[m.content for m in messages], + top_n=self.top_k, ) ranking = list( filter( diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py index e727165..fd42c2c 100644 --- a/rag/retriever/rerank/local.py +++ b/rag/retriever/rerank/local.py @@ -2,7 +2,7 @@ import os from typing import List from loguru import logger as log -from rag.rag import Message +from rag.message import Message from rag.retriever.vector import Document from sentence_transformers import CrossEncoder |