From d487ef8b04cc7f5ac1491f0638f902fe2abe5ac5 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 8 Apr 2024 22:28:47 +0200 Subject: Wip refactor --- rag/db/vector.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) (limited to 'rag/db/vector.py') diff --git a/rag/db/vector.py b/rag/db/vector.py index bbbbf32..fd2b2c2 100644 --- a/rag/db/vector.py +++ b/rag/db/vector.py @@ -5,7 +5,7 @@ from typing import Dict, List from loguru import logger as log from qdrant_client import QdrantClient from qdrant_client.http.models import StrictFloat -from qdrant_client.models import Distance, PointStruct, ScoredPoint, VectorParams +from qdrant_client.models import Distance, PointStruct, VectorParams @dataclass @@ -15,11 +15,18 @@ class Point: payload: Dict[str, str] +@dataclass +class Document: + title: str + text: str + + class VectorDB: - def __init__(self): + def __init__(self, score_threshold: float = 0.6): self.dim = int(os.environ["EMBEDDING_DIM"]) self.collection_name = os.environ["QDRANT_COLLECTION_NAME"] self.client = QdrantClient(url=os.environ["QDRANT_URL"]) + self.score_threshold = score_threshold self.__configure() def __configure(self): @@ -47,12 +54,20 @@ class VectorDB: max_retries=3, ) - def search(self, query: List[float], limit: int = 4) -> List[ScoredPoint]: + def search(self, query: List[float], limit: int = 5) -> List[Document]: log.debug("Searching for vectors...") hits = self.client.search( collection_name=self.collection_name, query_vector=query, limit=limit, - score_threshold=0.6, + score_threshold=self.score_threshold, + ) + log.debug(f"Got {len(hits)} hits in the vector db with limit={limit}") + return list( + map( + lambda h: Document( + title=h.payload.get("source", ""), text=h.payload["text"] + ), + hits, + ) ) - return hits -- cgit v1.2.3-70-g09d2