1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
|
import os
from dataclasses import dataclass
from typing import Dict, List
from loguru import logger as log
from qdrant_client import QdrantClient
from qdrant_client.http.models import StrictFloat
from qdrant_client.models import Distance, PointStruct, ScoredPoint, VectorParams
@dataclass
class Point:
    """A single record to be stored in the vector database.

    Each instance is converted field-for-field into a ``PointStruct``
    when it is passed to ``VectorDB.add``.
    """

    # Identifier forwarded unchanged as the ``PointStruct`` id.
    id: str
    # Embedding values for this record.
    vector: List[StrictFloat]
    # String-to-string metadata stored alongside the vector.
    payload: Dict[str, str]
class VectorDB:
    """Thin wrapper around a single Qdrant collection.

    Configuration is read from the environment variables
    ``EMBEDDING_DIM``, ``QDRANT_COLLECTION_NAME`` and ``QDRANT_URL``.
    The collection is created on construction if it does not exist yet.
    """

    def __init__(self):
        """Connect to Qdrant and make sure the collection exists.

        Raises:
            KeyError: If one of the required environment variables is unset.
            ValueError: If ``EMBEDDING_DIM`` cannot be parsed as an integer.
        """
        self.dim = int(os.environ["EMBEDDING_DIM"])
        self.collection_name = os.environ["QDRANT_COLLECTION_NAME"]
        self.client = QdrantClient(url=os.environ["QDRANT_URL"])
        self.__configure()

    def __configure(self):
        """Create the configured collection unless it is already present.

        New collections use ``self.dim`` as the vector size and cosine
        distance.
        """
        # A set gives O(1) membership and avoids the list(map(lambda ...)) dance.
        existing = {
            col.name for col in self.client.get_collections().collections
        }
        if self.collection_name in existing:
            log.debug(f"Collection {self.collection_name} already exists...")
            return

        log.debug(f"Configuring collection {self.collection_name}...")
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(size=self.dim, distance=Distance.COSINE),
        )

    def add(self, points: List[Point]):
        """Upload ``points`` into the collection.

        Args:
            points: Records to insert. An empty list is a no-op.
        """
        if not points:
            # Nothing to send: skip the parallel upload call entirely.
            log.debug("No vectors to insert; skipping upload.")
            return

        log.debug(f"Inserting {len(points)} vectors into the vector db...")
        self.client.upload_points(
            collection_name=self.collection_name,
            points=[
                PointStruct(id=point.id, vector=point.vector, payload=point.payload)
                for point in points
            ],
            parallel=4,
            max_retries=3,
        )

    def search(self, query: List[float], limit: int = 4) -> List[ScoredPoint]:
        """Return the ``limit`` nearest stored points to ``query``.

        Args:
            query: Query embedding.
            limit: Maximum number of hits to return.

        Returns:
            The scored points reported by the client.
        """
        log.debug("Searching for vectors...")

        # ``search`` is deprecated in qdrant-client in favour of
        # ``query_points``. Keep using it wherever it still exists so current
        # behaviour is unchanged, and fall back only when the installed
        # client no longer provides it.
        legacy_search = getattr(self.client, "search", None)
        if legacy_search is not None:
            return legacy_search(
                collection_name=self.collection_name,
                query_vector=query,
                limit=limit,
            )

        response = self.client.query_points(
            collection_name=self.collection_name, query=query, limit=limit
        )
        return response.points
|