summary | refs | log | tree | commit | diff
path: root/rag/retriever/vector.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/retriever/vector.py')
-rw-r--r-- rag/retriever/vector.py 32
1 files changed, 24 insertions, 8 deletions
diff --git a/rag/retriever/vector.py b/rag/retriever/vector.py
index 1a484f3..b36aee8 100644
--- a/rag/retriever/vector.py
+++ b/rag/retriever/vector.py
@@ -21,6 +21,20 @@ class Document:
text: str
+@dataclass  # generates __init__/__repr__/__eq__ from the single field declared below
+class Documents:  # ordered collection of Document objects; this is what VectorDB.search returns
+ documents: List[Document]  # wrapped documents in their current order; replaced wholesale by rerank()
+
+ def __len__(self):  # makes len(docs) report the number of wrapped documents
+ return len(self.documents)
+
+ def content(self) -> List[str]:  # text of every document, in current order (titles are not included)
+ return [d.text for d in self.documents]
+
+ def rerank(self, rankings: List[int]):  # rankings holds indices into self.documents; mutates self, returns None
+ self.documents = [self.documents[r] for r in rankings]  # new position i takes old documents[rankings[i]]; omitted indices drop documents, repeated ones duplicate, out-of-range raises IndexError
+
+
class VectorDB:
def __init__(self):
self.dim = int(os.environ["EMBEDDING_DIM"])
@@ -47,7 +61,7 @@ class VectorDB:
log.info(f"Deleting collection {self.collection_name}")
self.client.delete_collection(self.collection_name)
- def add(self, points: List[Point]):
+ def index(self, points: List[Point]):
log.debug(f"Inserting {len(points)} vectors into the vector db...")
self.client.upload_points(
collection_name=self.collection_name,
@@ -59,7 +73,7 @@ class VectorDB:
max_retries=3,
)
- def search(self, query: List[float]) -> List[Document]:
+ def search(self, query: List[float]) -> Documents:
log.debug("Searching for vectors...")
hits = self.client.search(
collection_name=self.collection_name,
@@ -68,11 +82,13 @@ class VectorDB:
score_threshold=self.score_threshold,
)
log.debug(f"Got {len(hits)} hits in the vector db with limit={self.top_k}")
- return list(
- map(
- lambda h: Document(
- title=h.payload.get("source", ""), text=h.payload["text"]
- ),
- hits,
+ return Documents(
+ list(
+ map(
+ lambda h: Document(
+ title=h.payload.get("source", ""), text=h.payload["text"]
+ ),
+ hits,
+ )
)
)