diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-24 01:10:43 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-24 01:10:43 +0200 |
commit | 95305f59df84caded50286b1a57b6075e48725a8 (patch) | |
tree | c0f0157a99da6332a3c96462b0aba2bd02dfcb33 /rag/retriever | |
parent | 75be0914f6bd2cdeda1539f83b38fcbc854d5cfa (diff) |
Rerank working
llama3 sucks at rag
Diffstat (limited to 'rag/retriever')
-rw-r--r-- | rag/retriever/rerank.py | 77 | ||||
-rw-r--r-- | rag/retriever/retriever.py | 4 | ||||
-rw-r--r-- | rag/retriever/vector.py | 11 |
3 files changed, 85 insertions, 7 deletions
diff --git a/rag/retriever/rerank.py b/rag/retriever/rerank.py
new file mode 100644
index 0000000..08a9a27
--- /dev/null
+++ b/rag/retriever/rerank.py
@@ -0,0 +1,77 @@
+import os
+from abc import abstractmethod
+from typing import Type
+
+import cohere
+from loguru import logger as log
+from sentence_transformers import CrossEncoder
+
+from rag.generator.prompt import Prompt
+
+
+class AbstractReranker(type):
+    _instances = {}
+
+    def __call__(cls, *args, **kwargs):
+        if cls not in cls._instances:
+            instance = super().__call__(*args, **kwargs)
+            cls._instances[cls] = instance
+        return cls._instances[cls]
+
+    @abstractmethod
+    def rank(self, prompt: Prompt) -> Prompt:
+        return prompt
+
+
+class Reranker(metaclass=AbstractReranker):
+    def __init__(self) -> None:
+        self.model = CrossEncoder(os.environ["RERANK_MODEL"])
+        self.top_k = int(os.environ["RERANK_TOP_K"])
+
+    def rank(self, prompt: Prompt) -> Prompt:
+        if prompt.documents:
+            results = self.model.rank(
+                query=prompt.query,
+                documents=[d.text for d in prompt.documents],
+                return_documents=False,
+                top_k=self.top_k,
+            )
+            ranking = list(filter(lambda x: x.get("score", 0.0) > 0.5, results))
+            log.debug(
+                f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}"
+            )
+            prompt.documents = [
+                prompt.documents[r.get("corpus_id", 0)] for r in ranking
+            ]
+        return prompt
+
+
+class CohereReranker(metaclass=AbstractReranker):
+    def __init__(self) -> None:
+        self.client = cohere.Client(os.environ["COHERE_API_KEY"])
+        self.top_k = int(os.environ["RERANK_TOP_K"])
+
+    def rank(self, prompt: Prompt) -> Prompt:
+        if prompt.documents:
+            response = self.client.rerank(
+                model="rerank-english-v3.0",
+                query=prompt.query,
+                documents=[d.text for d in prompt.documents],
+                top_n=self.top_k,
+            )
+            ranking = list(filter(lambda x: x.relevance_score > 0.5, response.results))
+            log.debug(
+                f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}"
+            )
+            prompt.documents = [prompt.documents[r.index] for r in ranking]
+        return prompt
+
+
+def get_reranker(model: str) -> Type[AbstractReranker]:
+    match model:
+        case "local":
+            return Reranker()
+        case "cohere":
+            return CohereReranker()
+        case _:
+            exit(1)
diff --git a/rag/retriever/retriever.py b/rag/retriever/retriever.py
index deffae5..351cfb0 100644
--- a/rag/retriever/retriever.py
+++ b/rag/retriever/retriever.py
@@ -45,7 +45,7 @@ class Retriever:
         else:
             log.error("Invalid input!")
 
-    def retrieve(self, query: str, limit: int = 5) -> List[Document]:
+    def retrieve(self, query: str) -> List[Document]:
         log.debug(f"Finding documents matching query: {query}")
         query_emb = self.encoder.encode_query(query)
-        return self.vec_db.search(query_emb, limit)
+        return self.vec_db.search(query_emb)
diff --git a/rag/retriever/vector.py b/rag/retriever/vector.py
index b72a3c1..1a484f3 100644
--- a/rag/retriever/vector.py
+++ b/rag/retriever/vector.py
@@ -22,11 +22,12 @@ class Document:
 
 
 class VectorDB:
-    def __init__(self, score_threshold: float = 0.5):
+    def __init__(self):
         self.dim = int(os.environ["EMBEDDING_DIM"])
         self.collection_name = os.environ["QDRANT_COLLECTION_NAME"]
         self.client = QdrantClient(url=os.environ["QDRANT_URL"])
-        self.score_threshold = score_threshold
+        self.top_k = int(os.environ["RETRIEVER_TOP_K"])
+        self.score_threshold = float(os.environ["RETRIEVER_SCORE_THRESHOLD"])
         self.__configure()
 
     def __configure(self):
@@ -58,15 +59,15 @@ class VectorDB:
             max_retries=3,
         )
 
-    def search(self, query: List[float], limit: int = 5) -> List[Document]:
+    def search(self, query: List[float]) -> List[Document]:
         log.debug("Searching for vectors...")
         hits = self.client.search(
             collection_name=self.collection_name,
             query_vector=query,
-            limit=limit,
+            limit=self.top_k,
             score_threshold=self.score_threshold,
         )
-        log.debug(f"Got {len(hits)} hits in the vector db with limit={limit}")
+        log.debug(f"Got {len(hits)} hits in the vector db with limit={self.top_k}")
         return list(
             map(
                 lambda h: Document(