summaryrefslogtreecommitdiff
path: root/rag/retriever/vector.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/retriever/vector.py')
-rw-r--r--rag/retriever/vector.py11
1 files changed, 6 insertions, 5 deletions
diff --git a/rag/retriever/vector.py b/rag/retriever/vector.py
index b72a3c1..1a484f3 100644
--- a/rag/retriever/vector.py
+++ b/rag/retriever/vector.py
@@ -22,11 +22,12 @@ class Document:
class VectorDB:
- def __init__(self, score_threshold: float = 0.5):
+ def __init__(self):
self.dim = int(os.environ["EMBEDDING_DIM"])
self.collection_name = os.environ["QDRANT_COLLECTION_NAME"]
self.client = QdrantClient(url=os.environ["QDRANT_URL"])
- self.score_threshold = score_threshold
+ self.top_k = int(os.environ["RETRIEVER_TOP_K"])
+ self.score_threshold = float(os.environ["RETRIEVER_SCORE_THRESHOLD"])
self.__configure()
def __configure(self):
@@ -58,15 +59,15 @@ class VectorDB:
max_retries=3,
)
- def search(self, query: List[float], limit: int = 5) -> List[Document]:
+ def search(self, query: List[float]) -> List[Document]:
log.debug("Searching for vectors...")
hits = self.client.search(
collection_name=self.collection_name,
query_vector=query,
- limit=limit,
+ limit=self.top_k,
score_threshold=self.score_threshold,
)
- log.debug(f"Got {len(hits)} hits in the vector db with limit={limit}")
+ log.debug(f"Got {len(hits)} hits in the vector db with limit={self.top_k}")
return list(
map(
lambda h: Document(