diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-08-05 00:37:21 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-08-05 00:37:21 +0200 |
commit | 5531d8147e52324a16c977f385715f934af5f246 (patch) | |
tree | 8688c70a4cfc1ee617c9533a401530bd15556bf9 /rag/retriever/rerank | |
parent | 5142aaaa356549ba7e7e9cdacf365326191831ac (diff) |
Fix broken stuff
Diffstat (limited to 'rag/retriever/rerank')
-rw-r--r-- | rag/retriever/rerank/local.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py index fd42c2c..231d50a 100644 --- a/rag/retriever/rerank/local.py +++ b/rag/retriever/rerank/local.py @@ -2,16 +2,16 @@ import os from typing import List from loguru import logger as log -from rag.message import Message -from rag.retriever.vector import Document from sentence_transformers import CrossEncoder +from rag.message import Message from rag.retriever.rerank.abstract import AbstractReranker +from rag.retriever.vector import Document class Reranker(metaclass=AbstractReranker): def __init__(self) -> None: - self.model = CrossEncoder(os.environ["RERANK_MODEL"]) + self.model = CrossEncoder(os.environ["RERANK_MODEL"], device="cpu") self.top_k = int(os.environ["RERANK_TOP_K"]) self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"]) @@ -33,7 +33,7 @@ class Reranker(metaclass=AbstractReranker): def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]: results = self.model.rank( query=query, - documents=[m.message for m in messages], + documents=[m.content for m in messages], return_documents=False, top_k=self.top_k, ) |