diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-06 01:21:52 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-06 01:21:52 +0200 |
commit | 13ac875b2269756045834d7a64e7b35acb9ce0b4 (patch) | |
tree | ab05dc7ba966de66e15cc8249ec2d772a2a4d34d /rag/db/vectors.py | |
parent | 59c77c93c39755526e3d7649660780584b1c090d (diff) |
Rename dbs
Diffstat (limited to 'rag/db/vectors.py')
-rw-r--r-- | rag/db/vectors.py | 53 |
1 files changed, 0 insertions, 53 deletions
diff --git a/rag/db/vectors.py b/rag/db/vectors.py deleted file mode 100644 index 9e8becb..0000000 --- a/rag/db/vectors.py +++ /dev/null @@ -1,53 +0,0 @@ -import os -from dataclasses import dataclass -from typing import Dict, List - -from qdrant_client import QdrantClient -from qdrant_client.http.models import StrictFloat -from qdrant_client.models import Distance, ScoredPoint, VectorParams, PointStruct -from loguru import logger as log - - -@dataclass -class Point: - id: str - vector: List[StrictFloat] - payload: Dict[str, str] - - -class Vectors: - def __init__(self): - self.dim = int(os.environ["EMBEDDING_DIM"]) - self.collection_name = os.environ["QDRANT_COLLECTION_NAME"] - self.client = QdrantClient(url=os.environ["QDRANT_URL"]) - self.__configure() - - def __configure(self): - collections = list(map(lambda col: col.name, self.client.get_collections())) - if self.collection_name not in collections: - log.debug(f"Configuring collection {self.collection_name}...") - self.client.create_collection( - collection_name=self.collection_name, - vectors_config=VectorParams(size=self.dim, distance=Distance.COSINE), - ) - else: - log.debug(f"Collection {self.collection_name} already exists...") - - def add(self, points: List[Point]): - log.debug(f"Inserting {len(points)} vectors into the vector db...") - self.client.upload_points( - collection_name=self.collection_name, - points=[ - PointStruct(id=point.id, vector=point.vector, payload=point.payload) - for point in points - ], - parallel=4, - max_retries=3, - ) - - def search(self, query: List[float], limit: int = 4) -> List[ScoredPoint]: - log.debug("Searching for vectors...") - hits = self.client.search( - collection_name=self.collection_name, query_vector=query, limit=limit - ) - return hits |