diff options
Diffstat (limited to 'rag/db/vector.py')
-rw-r--r-- | rag/db/vector.py | 25 |
1 files changed, 20 insertions, 5 deletions
diff --git a/rag/db/vector.py b/rag/db/vector.py index bbbbf32..fd2b2c2 100644 --- a/rag/db/vector.py +++ b/rag/db/vector.py @@ -5,7 +5,7 @@ from typing import Dict, List from loguru import logger as log from qdrant_client import QdrantClient from qdrant_client.http.models import StrictFloat -from qdrant_client.models import Distance, PointStruct, ScoredPoint, VectorParams +from qdrant_client.models import Distance, PointStruct, VectorParams @dataclass @@ -15,11 +15,18 @@ class Point: payload: Dict[str, str] +@dataclass +class Document: + title: str + text: str + + class VectorDB: - def __init__(self): + def __init__(self, score_threshold: float = 0.6): self.dim = int(os.environ["EMBEDDING_DIM"]) self.collection_name = os.environ["QDRANT_COLLECTION_NAME"] self.client = QdrantClient(url=os.environ["QDRANT_URL"]) + self.score_threshold = score_threshold self.__configure() def __configure(self): @@ -47,12 +54,20 @@ class VectorDB: max_retries=3, ) - def search(self, query: List[float], limit: int = 4) -> List[ScoredPoint]: + def search(self, query: List[float], limit: int = 5) -> List[Document]: log.debug("Searching for vectors...") hits = self.client.search( collection_name=self.collection_name, query_vector=query, limit=limit, - score_threshold=0.6, + score_threshold=self.score_threshold, + ) + log.debug(f"Got {len(hits)} hits in the vector db with limit={limit}") + return list( + map( + lambda h: Document( + title=h.payload.get("source", ""), text=h.payload["text"] + ), + hits, + ) ) - return hits |