summary | refs | log | tree | commit | diff
path: root/rag/retriever
diff options
context:
space:
mode:
author: Gustaf Rydholm <gustaf.rydholm@gmail.com> 2024-04-24 01:10:43 +0200
committer: Gustaf Rydholm <gustaf.rydholm@gmail.com> 2024-04-24 01:10:43 +0200
commit: 95305f59df84caded50286b1a57b6075e48725a8 (patch)
tree: c0f0157a99da6332a3c96462b0aba2bd02dfcb33 /rag/retriever
parent: 75be0914f6bd2cdeda1539f83b38fcbc854d5cfa (diff)
Rerank working
llama3 sucks at rag
Diffstat (limited to 'rag/retriever')
-rw-r--r-- rag/retriever/rerank.py 77
-rw-r--r-- rag/retriever/retriever.py 4
-rw-r--r-- rag/retriever/vector.py 11
3 files changed, 85 insertions, 7 deletions
diff --git a/rag/retriever/rerank.py b/rag/retriever/rerank.py
new file mode 100644
index 0000000..08a9a27
--- /dev/null
+++ b/rag/retriever/rerank.py
@@ -0,0 +1,77 @@
+import os
+from abc import abstractmethod
+from typing import Type
+
+import cohere
+from loguru import logger as log
+from sentence_transformers import CrossEncoder
+
+from rag.generator.prompt import Prompt
+
+
class AbstractReranker(type):
    """Metaclass that gives every class built with it one shared instance.

    The first call to such a class constructs the instance; every later call
    returns that same object and its arguments are ignored.
    """

    # Maps each class created through this metaclass to its single instance.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        """Return the cached instance of ``cls``, constructing it on first use."""
        if cls in cls._instances:
            return cls._instances[cls]
        created = super().__call__(*args, **kwargs)
        cls._instances[cls] = created
        return created

    @abstractmethod
    def rank(self, prompt: Prompt) -> Prompt:
        """Default ranking hook: hand ``prompt`` back unchanged.

        NOTE(review): ``abstractmethod`` is only enforced by ``ABCMeta``; this
        metaclass derives from plain ``type``, so the decorator does not stop
        classes that omit ``rank`` from being instantiated.
        """
        return prompt
+
+
class Reranker(metaclass=AbstractReranker):
    """Rerank a prompt's documents locally with a ``CrossEncoder`` model.

    Configuration is read from the environment when the instance is built:
    ``RERANK_MODEL`` is passed to ``CrossEncoder`` and ``RERANK_TOP_K`` caps
    how many results the model is asked to return. A missing variable raises
    ``KeyError``.
    """

    # A result is kept only when its score is strictly greater than this.
    # Kept as a class attribute so a subclass or caller can tune the cut-off
    # without editing ``rank``.
    score_threshold: float = 0.5

    def __init__(self) -> None:
        self.model = CrossEncoder(os.environ["RERANK_MODEL"])
        self.top_k = int(os.environ["RERANK_TOP_K"])

    def rank(self, prompt: Prompt) -> Prompt:
        """Replace ``prompt.documents`` with the documents the model rates relevant.

        Args:
            prompt: Carries the ``query`` and the candidate ``documents``; each
                document's ``text`` is what gets scored.

        Returns:
            The same ``prompt`` object. When it had documents, the list is
            replaced by those scoring above ``score_threshold``, in the order
            the model returned them. A prompt without documents is returned
            untouched and the model is not called.
        """
        if not prompt.documents:
            return prompt

        candidates = prompt.documents
        results = self.model.rank(
            query=prompt.query,
            documents=[d.text for d in candidates],
            return_documents=False,
            top_k=self.top_k,
        )
        # Skip any result lacking "corpus_id" instead of defaulting it to 0,
        # which would silently select the first document for that entry.
        ranking = [
            r
            for r in results
            if "corpus_id" in r and r.get("score", 0.0) > self.score_threshold
        ]
        log.debug(
            f"Reranking gave {len(ranking)} relevant documents of {len(candidates)}"
        )
        prompt.documents = [candidates[r["corpus_id"]] for r in ranking]
        return prompt
+
+
class CohereReranker(metaclass=AbstractReranker):
    """Rerank a prompt's documents through the Cohere client's ``rerank`` call.

    Configuration is read from the environment when the instance is built:
    ``COHERE_API_KEY`` is passed to ``cohere.Client`` and ``RERANK_TOP_K`` is
    sent as ``top_n``. A missing variable raises ``KeyError``.
    """

    # Model identifier sent with every rerank request; overridable so another
    # model can be selected without editing ``rank``.
    model_name: str = "rerank-english-v3.0"
    # A result is kept only when its relevance score is strictly greater than this.
    score_threshold: float = 0.5

    def __init__(self) -> None:
        self.client = cohere.Client(os.environ["COHERE_API_KEY"])
        self.top_k = int(os.environ["RERANK_TOP_K"])

    def rank(self, prompt: Prompt) -> Prompt:
        """Replace ``prompt.documents`` with the documents Cohere rates relevant.

        Args:
            prompt: Carries the ``query`` and the candidate ``documents``; each
                document's ``text`` is what gets sent for scoring.

        Returns:
            The same ``prompt`` object. When it had documents, the list is
            replaced by those whose ``relevance_score`` exceeds
            ``score_threshold``, in the order the response listed them. A
            prompt without documents is returned untouched and no request is
            made.
        """
        if not prompt.documents:
            return prompt

        candidates = prompt.documents
        response = self.client.rerank(
            model=self.model_name,
            query=prompt.query,
            documents=[d.text for d in candidates],
            top_n=self.top_k,
        )
        ranking = [
            r for r in response.results if r.relevance_score > self.score_threshold
        ]
        log.debug(
            f"Reranking gave {len(ranking)} relevant documents of {len(candidates)}"
        )
        # ``index`` on each result is used as the position in the submitted list.
        prompt.documents = [candidates[r.index] for r in ranking]
        return prompt
+
+
+def get_reranker(model: str) -> Type[AbstractReranker]:
+ match model:
+ case "local":
+ return Reranker()
+ case "cohere":
+ return CohereReranker()
+ case _:
+ exit(1)
diff --git a/rag/retriever/retriever.py b/rag/retriever/retriever.py
index deffae5..351cfb0 100644
--- a/rag/retriever/retriever.py
+++ b/rag/retriever/retriever.py
@@ -45,7 +45,7 @@ class Retriever:
else:
log.error("Invalid input!")
- def retrieve(self, query: str, limit: int = 5) -> List[Document]:
+ def retrieve(self, query: str) -> List[Document]:
log.debug(f"Finding documents matching query: {query}")
query_emb = self.encoder.encode_query(query)
- return self.vec_db.search(query_emb, limit)
+ return self.vec_db.search(query_emb)
diff --git a/rag/retriever/vector.py b/rag/retriever/vector.py
index b72a3c1..1a484f3 100644
--- a/rag/retriever/vector.py
+++ b/rag/retriever/vector.py
@@ -22,11 +22,12 @@ class Document:
class VectorDB:
- def __init__(self, score_threshold: float = 0.5):
+ def __init__(self):
self.dim = int(os.environ["EMBEDDING_DIM"])
self.collection_name = os.environ["QDRANT_COLLECTION_NAME"]
self.client = QdrantClient(url=os.environ["QDRANT_URL"])
- self.score_threshold = score_threshold
+ self.top_k = int(os.environ["RETRIEVER_TOP_K"])
+ self.score_threshold = float(os.environ["RETRIEVER_SCORE_THRESHOLD"])
self.__configure()
def __configure(self):
@@ -58,15 +59,15 @@ class VectorDB:
max_retries=3,
)
- def search(self, query: List[float], limit: int = 5) -> List[Document]:
+ def search(self, query: List[float]) -> List[Document]:
log.debug("Searching for vectors...")
hits = self.client.search(
collection_name=self.collection_name,
query_vector=query,
- limit=limit,
+ limit=self.top_k,
score_threshold=self.score_threshold,
)
- log.debug(f"Got {len(hits)} hits in the vector db with limit={limit}")
+ log.debug(f"Got {len(hits)} hits in the vector db with limit={self.top_k}")
return list(
map(
lambda h: Document(