From 28a1f5d4eddab6eb7c9ca77346c6fa9608856dd5 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Mon, 25 Aug 2025 00:06:19 +0200
Subject: Broken state

---
 rag/retriever/vector.py | 32 ++++++++++++++++++++++++--------
 1 file changed, 24 insertions(+), 8 deletions(-)

(limited to 'rag/retriever/vector.py')

diff --git a/rag/retriever/vector.py b/rag/retriever/vector.py
index 1a484f3..b36aee8 100644
--- a/rag/retriever/vector.py
+++ b/rag/retriever/vector.py
@@ -21,6 +21,20 @@ class Document:
     text: str
 
 
+@dataclass
+class Documents:
+    documents: List[Document]
+
+    def __len__(self):
+        return len(self.documents)
+
+    def content(self) -> List[str]:
+        return [d.text for d in self.documents]
+
+    def rerank(self, rankings: List[int]):
+        self.documents = [self.documents[r] for r in rankings]
+
+
 class VectorDB:
     def __init__(self):
         self.dim = int(os.environ["EMBEDDING_DIM"])
@@ -47,7 +61,7 @@ class VectorDB:
         log.info(f"Deleting collection {self.collection_name}")
         self.client.delete_collection(self.collection_name)
 
-    def add(self, points: List[Point]):
+    def index(self, points: List[Point]):
         log.debug(f"Inserting {len(points)} vectors into the vector db...")
         self.client.upload_points(
             collection_name=self.collection_name,
@@ -59,7 +73,7 @@ class VectorDB:
             max_retries=3,
         )
 
-    def search(self, query: List[float]) -> List[Document]:
+    def search(self, query: List[float]) -> Documents:
         log.debug("Searching for vectors...")
         hits = self.client.search(
             collection_name=self.collection_name,
@@ -68,11 +82,13 @@ class VectorDB:
             score_threshold=self.score_threshold,
         )
         log.debug(f"Got {len(hits)} hits in the vector db with limit={self.top_k}")
-        return list(
-            map(
-                lambda h: Document(
-                    title=h.payload.get("source", ""), text=h.payload["text"]
-                ),
-                hits,
+        return Documents(
+            list(
+                map(
+                    lambda h: Document(
+                        title=h.payload.get("source", ""), text=h.payload["text"]
+                    ),
+                    hits,
+                )
             )
         )
-- 
cgit v1.2.3-70-g09d2