summaryrefslogtreecommitdiff
path: root/rag/db/vectors.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-06 00:18:57 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-06 00:18:57 +0200
commita1603d4c6c29f414304fc379074eb81b5b00c5d0 (patch)
tree2ebad5348fe62148db405a4637eb49274f7c9766 /rag/db/vectors.py
parent093553777355e6d1d6c2dc9b0326909bf9859cba (diff)
Add logging in dbs
Diffstat (limited to 'rag/db/vectors.py')
-rw-r--r--rag/db/vectors.py53
1 files changed, 53 insertions, 0 deletions
diff --git a/rag/db/vectors.py b/rag/db/vectors.py
new file mode 100644
index 0000000..9e8becb
--- /dev/null
+++ b/rag/db/vectors.py
@@ -0,0 +1,53 @@
+import os
+from dataclasses import dataclass
+from typing import Dict, List
+
+from qdrant_client import QdrantClient
+from qdrant_client.http.models import StrictFloat
+from qdrant_client.models import Distance, ScoredPoint, VectorParams, PointStruct
+from loguru import logger as log
+
+
+@dataclass
+class Point:
+ id: str
+ vector: List[StrictFloat]
+ payload: Dict[str, str]
+
+
+class Vectors:
+ def __init__(self):
+ self.dim = int(os.environ["EMBEDDING_DIM"])
+ self.collection_name = os.environ["QDRANT_COLLECTION_NAME"]
+ self.client = QdrantClient(url=os.environ["QDRANT_URL"])
+ self.__configure()
+
+ def __configure(self):
+ collections = list(map(lambda col: col.name, self.client.get_collections()))
+ if self.collection_name not in collections:
+ log.debug(f"Configuring collection {self.collection_name}...")
+ self.client.create_collection(
+ collection_name=self.collection_name,
+ vectors_config=VectorParams(size=self.dim, distance=Distance.COSINE),
+ )
+ else:
+ log.debug(f"Collection {self.collection_name} already exists...")
+
+ def add(self, points: List[Point]):
+ log.debug(f"Inserting {len(points)} vectors into the vector db...")
+ self.client.upload_points(
+ collection_name=self.collection_name,
+ points=[
+ PointStruct(id=point.id, vector=point.vector, payload=point.payload)
+ for point in points
+ ],
+ parallel=4,
+ max_retries=3,
+ )
+
+ def search(self, query: List[float], limit: int = 4) -> List[ScoredPoint]:
+ log.debug("Searching for vectors...")
+ hits = self.client.search(
+ collection_name=self.collection_name, query_vector=query, limit=limit
+ )
+ return hits