diff options
Diffstat (limited to 'rag/retriever/vector.py')
-rw-r--r-- | rag/retriever/vector.py | 32 |
1 files changed, 24 insertions, 8 deletions
diff --git a/rag/retriever/vector.py b/rag/retriever/vector.py index 1a484f3..b36aee8 100644 --- a/rag/retriever/vector.py +++ b/rag/retriever/vector.py @@ -21,6 +21,20 @@ class Document: text: str +@dataclass +class Documents: + documents: List[Document] + + def __len__(self): + return len(self.documents) + + def content(self) -> List[str]: + return [d.text for d in self.documents] + + def rerank(self, rankings: List[int]): + self.documents = [self.documents[r] for r in rankings] + + class VectorDB: def __init__(self): self.dim = int(os.environ["EMBEDDING_DIM"]) @@ -47,7 +61,7 @@ class VectorDB: log.info(f"Deleting collection {self.collection_name}") self.client.delete_collection(self.collection_name) - def add(self, points: List[Point]): + def index(self, points: List[Point]): log.debug(f"Inserting {len(points)} vectors into the vector db...") self.client.upload_points( collection_name=self.collection_name, @@ -59,7 +73,7 @@ class VectorDB: max_retries=3, ) - def search(self, query: List[float]) -> List[Document]: + def search(self, query: List[float]) -> Documents: log.debug("Searching for vectors...") hits = self.client.search( collection_name=self.collection_name, @@ -68,11 +82,13 @@ class VectorDB: score_threshold=self.score_threshold, ) log.debug(f"Got {len(hits)} hits in the vector db with limit={self.top_k}") - return list( - map( - lambda h: Document( - title=h.payload.get("source", ""), text=h.payload["text"] - ), - hits, + return Documents( + list( + map( + lambda h: Document( + title=h.payload.get("source", ""), text=h.payload["text"] + ), + hits, + ) ) ) |