diff options
Diffstat (limited to 'rag/db/vector.py')
-rw-r--r-- | rag/db/vector.py | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/rag/db/vector.py b/rag/db/vector.py new file mode 100644 index 0000000..f229ba7 --- /dev/null +++ b/rag/db/vector.py @@ -0,0 +1,23 @@ +from typing import Tuple +import faiss +import numpy as np + +# TODO: read from .env +EMBEDDING_DIM = 1024 + + +# TODO: inner product distance metric? +class VectorStore: + def __init__(self): + self.index = faiss.IndexFlatL2(EMBEDDING_DIM) + # TODO: load from file + + def add(self, embeddings: np.ndarray): + # TODO: save to file + self.index.add(embeddings) + + def search( + self, query: np.ndarray, neighbors: int = 4 + ) -> Tuple[np.ndarray, np.ndarray]: + score, indices = self.index.search(query, neighbors) + return score, indices |