summary refs log tree commit diff
path: root/rag/db/embeddings.py
diff options
context:
space:
mode:
author Gustaf Rydholm <gustaf.rydholm@gmail.com> 2024-04-05 18:31:12 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com> 2024-04-05 18:31:12 +0200
commit 1dfaf80c75afa84b6d03a0013eb1fd94d0257226 (patch)
tree 9d20d96449141e1323b28a05fd2d3c82d9738439 /rag/db/embeddings.py
parent fcecc38d440e3643c5417f06526509bb2c0ec83e (diff)
Update from faiss to qdrant
Diffstat (limited to 'rag/db/embeddings.py')
-rw-r--r-- rag/db/embeddings.py 52
1 files changed, 38 insertions, 14 deletions
diff --git a/rag/db/embeddings.py b/rag/db/embeddings.py
index 1a97d16..4f61dcc 100644
--- a/rag/db/embeddings.py
+++ b/rag/db/embeddings.py
@@ -1,23 +1,47 @@
import os
-from typing import Tuple
+from dataclasses import dataclass
+from typing import Dict, List
+from uuid import uuid4
-import faiss
-import numpy as np
+from qdrant_client import QdrantClient
+from qdrant_client.http.models import StrictFloat
+from qdrant_client.models import Distance, VectorParams, PointStruct
+
+
+@dataclass
+class Point:
+ # One record destined for the vector collection. Embeddings.add copies
+ # these three fields one-to-one into a qdrant PointStruct.
+ # id: identifier handed to PointStruct as the point id.
+ id: str
+ # vector: the float values stored for this point; presumably of length
+ # EMBEDDING_DIM to match the collection's VectorParams -- TODO confirm.
+ vector: List[StrictFloat]
+ # payload: string-to-string mapping stored alongside the vector.
+ payload: Dict[str, str]
-# TODO: inner product distance metric?
class Embeddings:
+ # Wrapper around a single Qdrant collection: the constructor (re)creates
+ # the collection, add() uploads points into it, and search() queries it.
def __init__(self):
+ # Configuration comes from the environment: EMBEDDING_DIM,
+ # QDRANT_COLLECTION_NAME and QDRANT_URL. A missing variable raises
+ # KeyError; a non-integer EMBEDDING_DIM raises ValueError from int().
self.dim = int(os.environ["EMBEDDING_DIM"])
- self.index = faiss.IndexFlatL2(self.dim)
- # TODO: load from file
+ self.collection_name = os.environ["QDRANT_COLLECTION_NAME"]
+ self.client = QdrantClient(url=os.environ["QDRANT_URL"])
+ # NOTE(review): the collection is deleted and then re-created on every
+ # construction, so points uploaded earlier do not survive creating a
+ # new Embeddings() -- confirm that this reset is intended.
+ self.client.delete_collection(
+ collection_name=self.collection_name,
+ )
+ # The collection is created with vectors of size self.dim compared by
+ # cosine distance (the removed faiss index used IndexFlatL2 instead).
+ self.client.create_collection(
+ collection_name=self.collection_name,
+ vectors_config=VectorParams(size=self.dim, distance=Distance.COSINE),
+ )
- def add(self, embeddings: np.ndarray):
- # TODO: save to file
- self.index.add(embeddings)
+ def add(self, points: List[Point]):
+ # Converts every Point into a PointStruct and uploads the whole list
+ # to the collection named by self.collection_name.
+ # NOTE(review): the print below looks like leftover debug output --
+ # consider removing it or routing it through logging.
+ print(len(points))
+ self.client.upload_points(
+ collection_name=self.collection_name,
+ points=[
+ PointStruct(id=point.id, vector=point.vector, payload=point.payload)
+ for point in points
+ ],
+ # parallel and max_retries are forwarded unchanged to upload_points;
+ # presumably the worker count and the retry limit -- verify against
+ # the qdrant-client documentation.
+ parallel=4,
+ max_retries=3,
+ )
- def search(
- self, query: np.ndarray, neighbors: int = 4
- ) -> Tuple[np.ndarray, np.ndarray]:
- score, indices = self.index.search(query, neighbors)
- return score, indices
+ def search(self, query: List[float], limit: int = 4):
+ # Passes `query` as the query vector and `limit` (default 4) to
+ # client.search on this collection and returns the client's result
+ # unchanged. The return type is not annotated here; it is whatever
+ # QdrantClient.search returns -- see the qdrant-client documentation.
+ hits = self.client.search(
+ collection_name=self.collection_name, query_vector=query, limit=limit
+ )
+ return hits