From 1dfaf80c75afa84b6d03a0013eb1fd94d0257226 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Fri, 5 Apr 2024 18:31:12 +0200 Subject: Update from faiss to qdrant --- rag/db/embeddings.py | 52 ++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 14 deletions(-) (limited to 'rag/db/embeddings.py') diff --git a/rag/db/embeddings.py b/rag/db/embeddings.py index 1a97d16..4f61dcc 100644 --- a/rag/db/embeddings.py +++ b/rag/db/embeddings.py @@ -1,23 +1,47 @@ import os -from typing import Tuple +from dataclasses import dataclass +from typing import Dict, List +from uuid import uuid4 -import faiss -import numpy as np +from qdrant_client import QdrantClient +from qdrant_client.http.models import StrictFloat +from qdrant_client.models import Distance, VectorParams, PointStruct + + +@dataclass +class Point: + id: str + vector: List[StrictFloat] + payload: Dict[str, str] -# TODO: inner product distance metric? class Embeddings: def __init__(self): self.dim = int(os.environ["EMBEDDING_DIM"]) - self.index = faiss.IndexFlatL2(self.dim) - # TODO: load from file + self.collection_name = os.environ["QDRANT_COLLECTION_NAME"] + self.client = QdrantClient(url=os.environ["QDRANT_URL"]) + self.client.delete_collection( + collection_name=self.collection_name, + ) + self.client.create_collection( + collection_name=self.collection_name, + vectors_config=VectorParams(size=self.dim, distance=Distance.COSINE), + ) - def add(self, embeddings: np.ndarray): - # TODO: save to file - self.index.add(embeddings) + def add(self, points: List[Point]): + print(len(points)) + self.client.upload_points( + collection_name=self.collection_name, + points=[ + PointStruct(id=point.id, vector=point.vector, payload=point.payload) + for point in points + ], + parallel=4, + max_retries=3, + ) - def search( - self, query: np.ndarray, neighbors: int = 4 - ) -> Tuple[np.ndarray, np.ndarray]: - score, indices = self.index.search(query, neighbors) - return score, indices + def search(self, query: List[float], limit: int = 4): + hits = self.client.search( + collection_name=self.collection_name, query_vector=query, limit=limit + ) + return hits -- cgit v1.2.3-70-g09d2