From aac821b148c6c0d35b940609dc9b0ddcb053b28e Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Wed, 19 Jun 2024 02:07:06 +0200 Subject: Still wip on rewrite --- rag/retriever/rerank/abstract.py | 2 +- rag/retriever/rerank/cohere.py | 10 +++++----- rag/retriever/rerank/local.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) (limited to 'rag/retriever/rerank') diff --git a/rag/retriever/rerank/abstract.py b/rag/retriever/rerank/abstract.py index f32ee77..015a60d 100644 --- a/rag/retriever/rerank/abstract.py +++ b/rag/retriever/rerank/abstract.py @@ -1,7 +1,7 @@ from abc import abstractmethod from typing import List -from rag.memory import Message +from rag.message import Message from rag.retriever.vector import Document diff --git a/rag/retriever/rerank/cohere.py b/rag/retriever/rerank/cohere.py index 33c373d..52f31a8 100644 --- a/rag/retriever/rerank/cohere.py +++ b/rag/retriever/rerank/cohere.py @@ -4,7 +4,7 @@ from typing import List import cohere from loguru import logger as log -from rag.rag import Message +from rag.message import Message from rag.retriever.rerank.abstract import AbstractReranker from rag.retriever.vector import Document @@ -35,11 +35,11 @@ class CohereReranker(metaclass=AbstractReranker): def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]: - response = self.model.rank( + response = self.client.rerank( + model="rerank-english-v3.0", query=query, - documents=[m.message for m in messages], - return_documents=False, - top_k=self.top_k, + documents=[m.content for m in messages], + top_n=self.top_k, ) ranking = list( filter( diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py index e727165..fd42c2c 100644 --- a/rag/retriever/rerank/local.py +++ b/rag/retriever/rerank/local.py @@ -2,7 +2,7 @@ import os from typing import List from loguru import logger as log -from rag.rag import Message +from rag.message import Message from rag.retriever.vector import Document from sentence_transformers import CrossEncoder -- cgit v1.2.3-70-g09d2