summaryrefslogtreecommitdiff
path: root/rag/db/embeddings.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-05 00:42:19 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-05 00:42:19 +0200
commit0a6b7793e982d4c59f3c6ffb947dfe28ca709cdc (patch)
tree10ae6893d2c901b04f63c5f9e15cb2678b87ce98 /rag/db/embeddings.py
parent1cf0a401054c3e3ebde60bfd73ad15e39bc531e6 (diff)
Rename
Diffstat (limited to 'rag/db/embeddings.py')
-rw-r--r--rag/db/embeddings.py23
1 files changed, 23 insertions, 0 deletions
diff --git a/rag/db/embeddings.py b/rag/db/embeddings.py
new file mode 100644
index 0000000..f9d0735
--- /dev/null
+++ b/rag/db/embeddings.py
@@ -0,0 +1,23 @@
+import os
+from typing import Tuple
+
+import faiss
+import numpy as np
+
+
+# TODO: inner product distance metric?
+class Embeddings:
+ def __init__(self):
+ self.dim = os.environ["EMBEDDING_DIM"]
+ self.index = faiss.IndexFlatL2(self.dim)
+ # TODO: load from file
+
+ def add(self, embeddings: np.ndarray):
+ # TODO: save to file
+ self.index.add(embeddings)
+
+ def search(
+ self, query: np.ndarray, neighbors: int = 4
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ score, indices = self.index.search(query, neighbors)
+ return score, indices