diff options
Diffstat (limited to 'rag/retriever/vector.py')
-rw-r--r-- | rag/retriever/vector.py | 11 |
1 files changed, 6 insertions, 5 deletions
diff --git a/rag/retriever/vector.py b/rag/retriever/vector.py index b72a3c1..1a484f3 100644 --- a/rag/retriever/vector.py +++ b/rag/retriever/vector.py @@ -22,11 +22,12 @@ class Document: class VectorDB: - def __init__(self, score_threshold: float = 0.5): + def __init__(self): self.dim = int(os.environ["EMBEDDING_DIM"]) self.collection_name = os.environ["QDRANT_COLLECTION_NAME"] self.client = QdrantClient(url=os.environ["QDRANT_URL"]) - self.score_threshold = score_threshold + self.top_k = int(os.environ["RETRIEVER_TOP_K"]) + self.score_threshold = float(os.environ["RETRIEVER_SCORE_THRESHOLD"]) self.__configure() def __configure(self): @@ -58,15 +59,15 @@ class VectorDB: max_retries=3, ) - def search(self, query: List[float], limit: int = 5) -> List[Document]: + def search(self, query: List[float]) -> List[Document]: log.debug("Searching for vectors...") hits = self.client.search( collection_name=self.collection_name, query_vector=query, - limit=limit, + limit=self.top_k, score_threshold=self.score_threshold, ) - log.debug(f"Got {len(hits)} hits in the vector db with limit={limit}") + log.debug(f"Got {len(hits)} hits in the vector db with limit={self.top_k}") return list( map( lambda h: Document( |