diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-09 00:14:00 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-09 00:14:00 +0200 |
commit | 91ddb3672e514fa9824609ff047d7cab0c65631a (patch) | |
tree | 009fd82618588d2960b5207128e86875f73cccdc /rag/retriever/vector.py | |
parent | d487ef8b04cc7f5ac1491f0638f902fe2abe5ac5 (diff) |
Refactor
Diffstat (limited to 'rag/retriever/vector.py')
-rw-r--r-- | rag/retriever/vector.py | 73 |
1 file changed, 73 insertions, 0 deletions
# rag/retriever/vector.py -- added as a new file by this commit
# (diff --git a/rag/retriever/vector.py b/rag/retriever/vector.py, index 0000000..fd2b2c2).
"""Thin wrapper around a Qdrant collection for storing and searching vectors."""

import os
from dataclasses import dataclass
from typing import Dict, List

from loguru import logger as log
from qdrant_client import QdrantClient
from qdrant_client.http.models import StrictFloat
from qdrant_client.models import Distance, PointStruct, VectorParams


@dataclass
class Point:
    """A vector to store, together with its id and string payload."""

    id: str
    vector: List[StrictFloat]
    payload: Dict[str, str]


@dataclass
class Document:
    """A search hit reduced to the two payload fields read back by ``search``."""

    title: str
    text: str


class VectorDB:
    """Client for a single Qdrant collection configured from the environment.

    Reads ``EMBEDDING_DIM``, ``QDRANT_COLLECTION_NAME`` and ``QDRANT_URL`` from
    ``os.environ`` and creates the collection on construction if it is missing.
    """

    def __init__(self, score_threshold: float = 0.6):
        """Connect to Qdrant and make sure the collection exists.

        Args:
            score_threshold: Value forwarded as ``score_threshold`` on every
                ``search`` call.

        Raises:
            KeyError: If one of the required environment variables is unset.
            ValueError: If ``EMBEDDING_DIM`` is not an integer.
        """
        self.dim = int(os.environ["EMBEDDING_DIM"])
        self.collection_name = os.environ["QDRANT_COLLECTION_NAME"]
        self.client = QdrantClient(url=os.environ["QDRANT_URL"])
        self.score_threshold = score_threshold
        self.__configure()

    def __configure(self):
        """Create the collection (cosine distance, ``self.dim`` wide) if absent."""
        existing = {
            col.name for col in self.client.get_collections().collections
        }
        if self.collection_name in existing:
            log.debug(f"Collection {self.collection_name} already exists...")
            return

        log.debug(f"Configuring collection {self.collection_name}...")
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(size=self.dim, distance=Distance.COSINE),
        )

    def add(self, points: List[Point]):
        """Upload ``points`` to the collection.

        An empty list is a no-op, so no upload (with ``parallel=4``) is started
        for nothing.
        """
        if not points:
            log.debug("No vectors to insert into the vector db...")
            return

        log.debug(f"Inserting {len(points)} vectors into the vector db...")
        self.client.upload_points(
            collection_name=self.collection_name,
            points=[
                PointStruct(id=point.id, vector=point.vector, payload=point.payload)
                for point in points
            ],
            parallel=4,
            max_retries=3,
        )

    def search(self, query: List[float], limit: int = 5) -> List[Document]:
        """Return documents for the hits Qdrant reports for ``query``.

        Args:
            query: Query vector passed to the client as ``query_vector``.
            limit: Maximum number of hits requested.

        Returns:
            One ``Document`` per hit whose payload contains ``"text"``. The
            title comes from the ``"source"`` payload key and defaults to ``""``.
            Hits without a payload or without ``"text"`` are skipped with a
            warning instead of aborting the whole search.
        """
        log.debug("Searching for vectors...")
        hits = self.client.search(
            collection_name=self.collection_name,
            query_vector=query,
            limit=limit,
            score_threshold=self.score_threshold,
        )
        log.debug(f"Got {len(hits)} hits in the vector db with limit={limit}")

        documents = []
        for hit in hits:
            # A hit may carry no payload at all; treat that like an empty one.
            payload = hit.payload or {}
            if "text" not in payload:
                log.warning(f"Skipping hit {hit.id} without a text payload")
                continue
            documents.append(
                Document(title=payload.get("source", ""), text=payload["text"])
            )
        return documents