summaryrefslogtreecommitdiff
path: root/rag/db/vector.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/db/vector.py')
-rw-r--r--rag/db/vector.py25
1 files changed, 20 insertions, 5 deletions
diff --git a/rag/db/vector.py b/rag/db/vector.py
index bbbbf32..fd2b2c2 100644
--- a/rag/db/vector.py
+++ b/rag/db/vector.py
@@ -5,7 +5,7 @@ from typing import Dict, List
from loguru import logger as log
from qdrant_client import QdrantClient
from qdrant_client.http.models import StrictFloat
-from qdrant_client.models import Distance, PointStruct, ScoredPoint, VectorParams
+from qdrant_client.models import Distance, PointStruct, VectorParams
@dataclass
@@ -15,11 +15,18 @@ class Point:
payload: Dict[str, str]
+@dataclass
+class Document:
+ title: str
+ text: str
+
+
class VectorDB:
- def __init__(self):
+ def __init__(self, score_threshold: float = 0.6):
self.dim = int(os.environ["EMBEDDING_DIM"])
self.collection_name = os.environ["QDRANT_COLLECTION_NAME"]
self.client = QdrantClient(url=os.environ["QDRANT_URL"])
+ self.score_threshold = score_threshold
self.__configure()
def __configure(self):
@@ -47,12 +54,20 @@ class VectorDB:
max_retries=3,
)
- def search(self, query: List[float], limit: int = 4) -> List[ScoredPoint]:
+ def search(self, query: List[float], limit: int = 5) -> List[Document]:
log.debug("Searching for vectors...")
hits = self.client.search(
collection_name=self.collection_name,
query_vector=query,
limit=limit,
- score_threshold=0.6,
+ score_threshold=self.score_threshold,
+ )
+ log.debug(f"Got {len(hits)} hits in the vector db with limit={limit}")
+ return list(
+ map(
+ lambda h: Document(
+ title=h.payload.get("source", ""), text=h.payload["text"]
+ ),
+ hits,
+ )
)
- return hits