summaryrefslogtreecommitdiff
path: root/rag/retriever/vector.py
blob: b72a3c161b1fe389139b7669a7df46c3cc225da0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
from dataclasses import dataclass
from typing import Dict, List

from loguru import logger as log
from qdrant_client import QdrantClient
from qdrant_client.http.models import StrictFloat
from qdrant_client.models import Distance, PointStruct, VectorParams


@dataclass
class Point:
    """A single vector record to be inserted into the vector database.

    ``VectorDB.add`` turns each instance into a qdrant ``PointStruct``,
    passing the three fields through unchanged.
    """

    # Identifier handed to the vector store as the point id.
    id: str
    # Vector components; presumably an embedding whose length matches the
    # collection's configured dimension -- confirm with the caller.
    vector: List[StrictFloat]
    # String-to-string metadata stored alongside the vector.
    # ``VectorDB.search`` reads the "text" and "source" keys back out of it.
    payload: Dict[str, str]


@dataclass
class Document:
    """A search result as returned by ``VectorDB.search``."""

    # Filled from the hit payload's "source" entry ("" when that key is absent).
    title: str
    # Filled from the hit payload's "text" entry.
    text: str


class VectorDB:
    """Thin wrapper around a single Qdrant collection of vectors.

    Settings are read from the environment when an instance is created:

    * ``EMBEDDING_DIM`` -- vector size used if the collection must be created.
    * ``QDRANT_COLLECTION_NAME`` -- name of the collection to operate on.
    * ``QDRANT_URL`` -- URL passed to ``QdrantClient``.
    """

    def __init__(self, score_threshold: float = 0.5):
        """Connect to Qdrant and make sure the configured collection exists.

        Args:
            score_threshold: Value forwarded as ``score_threshold`` on every
                search request issued by :meth:`search`.

        Raises:
            KeyError: If one of the required environment variables is unset.
            ValueError: If ``EMBEDDING_DIM`` cannot be parsed as an integer.
        """
        self.dim = int(os.environ["EMBEDDING_DIM"])
        self.collection_name = os.environ["QDRANT_COLLECTION_NAME"]
        self.client = QdrantClient(url=os.environ["QDRANT_URL"])
        self.score_threshold = score_threshold
        self.__configure()

    def __configure(self):
        """Create the collection with cosine distance unless it already exists."""
        response = self.client.get_collections()
        existing = {col.name for col in response.collections}
        if self.collection_name not in existing:
            log.debug(f"Configuring collection {self.collection_name}...")
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=self.dim, distance=Distance.COSINE),
            )
        else:
            log.debug(f"Collection {self.collection_name} already exists!")

    def delete_collection(self):
        """Delete the configured collection from the Qdrant server."""
        log.info(f"Deleting collection {self.collection_name}")
        self.client.delete_collection(self.collection_name)

    def add(self, points: List[Point]):
        """Upload ``points`` into the collection.

        Args:
            points: Records to insert; each one is converted to a
                ``PointStruct``. An empty list is a no-op.
        """
        log.debug(f"Inserting {len(points)} vectors into the vector db...")
        if not points:
            # Nothing to upload, so skip the client call entirely.
            return
        self.client.upload_points(
            collection_name=self.collection_name,
            points=[
                PointStruct(id=point.id, vector=point.vector, payload=point.payload)
                for point in points
            ],
            parallel=4,
            max_retries=3,
        )

    def search(self, query: List[float], limit: int = 5) -> List[Document]:
        """Return the documents whose vectors best match ``query``.

        Args:
            query: Query vector to search with.
            limit: Maximum number of hits requested from the server.

        Returns:
            One ``Document`` per usable hit, in the order the server returned
            them. Hits whose payload is missing or has no "text" entry are
            skipped (with a warning) so that a single malformed point does
            not fail the whole search.
        """
        log.debug("Searching for vectors...")
        hits = self.client.search(
            collection_name=self.collection_name,
            query_vector=query,
            limit=limit,
            score_threshold=self.score_threshold,
        )
        log.debug(f"Got {len(hits)} hits in the vector db with limit={limit}")

        documents: List[Document] = []
        for hit in hits:
            # The payload of a scored point is optional; treat None as empty.
            payload = hit.payload or {}
            if "text" not in payload:
                log.warning(f"Skipping hit {hit.id}: payload has no 'text' entry")
                continue
            documents.append(
                Document(title=payload.get("source", ""), text=payload["text"])
            )
        return documents