diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-05 00:42:19 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-05 00:42:19 +0200 |
commit | 0a6b7793e982d4c59f3c6ffb947dfe28ca709cdc (patch) | |
tree | 10ae6893d2c901b04f63c5f9e15cb2678b87ce98 /rag/db/embeddings.py | |
parent | 1cf0a401054c3e3ebde60bfd73ad15e39bc531e6 (diff) |
Rename
Diffstat (limited to 'rag/db/embeddings.py')
-rw-r--r-- | rag/db/embeddings.py | 23 |
1 files changed, 23 insertions, 0 deletions
# rag/db/embeddings.py
import os
from typing import Tuple

import faiss
import numpy as np


# TODO: inner product distance metric?
class Embeddings:
    """Thin wrapper around a flat (exhaustive) FAISS L2 index.

    The vector dimensionality is read from the ``EMBEDDING_DIM`` environment
    variable when the object is constructed.
    """

    def __init__(self):
        """Create an empty ``faiss.IndexFlatL2`` of ``EMBEDDING_DIM`` dimensions.

        Raises:
            KeyError: if ``EMBEDDING_DIM`` is not set.
            ValueError: if ``EMBEDDING_DIM`` is not a valid integer.
        """
        # Environment values are always strings, but the index constructor
        # needs an integer dimensionality, so convert explicitly.
        self.dim = int(os.environ["EMBEDDING_DIM"])
        self.index = faiss.IndexFlatL2(self.dim)
        # TODO: load from file

    def add(self, embeddings: np.ndarray) -> None:
        """Append vectors to the index.

        Args:
            embeddings: array of shape ``(n, dim)``; a single 1-D vector of
                length ``dim`` is treated as one row.
        """
        # TODO: save to file
        # FAISS works on C-contiguous float32 matrices; coerce so float64
        # input or a lone 1-D vector does not trip its input checks.
        vectors = np.ascontiguousarray(np.atleast_2d(embeddings), dtype=np.float32)
        self.index.add(vectors)

    def search(
        self, query: np.ndarray, neighbors: int = 4
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Return the nearest stored vectors for each query row.

        Args:
            query: array of shape ``(n, dim)``; a single 1-D vector of length
                ``dim`` is treated as one query.
            neighbors: number of nearest neighbours to return per query.

        Returns:
            ``(score, indices)`` exactly as returned by the underlying
            ``index.search`` call; per the FAISS documentation both have
            shape ``(n, neighbors)``.
        """
        # Same float32 / 2-D coercion as in ``add``.
        vectors = np.ascontiguousarray(np.atleast_2d(query), dtype=np.float32)
        score, indices = self.index.search(vectors, neighbors)
        return score, indices