summaryrefslogtreecommitdiff
path: root/rag/db
diff options
context:
space:
mode:
Diffstat (limited to 'rag/db')
-rw-r--r--rag/db/__init__.py0
-rw-r--r--rag/db/documents.py85
-rw-r--r--rag/db/vector.py23
3 files changed, 108 insertions, 0 deletions
diff --git a/rag/db/__init__.py b/rag/db/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/rag/db/__init__.py
diff --git a/rag/db/documents.py b/rag/db/documents.py
new file mode 100644
index 0000000..bdbf5a4
--- /dev/null
+++ b/rag/db/documents.py
@@ -0,0 +1,85 @@
+import os
+from typing import List
+import hashlib
+import psycopg
+from langchain_core.documents.base import Document
+
+TABLES = """
+CREATE TABLE IF NOT EXISTS chunk (
+ id serial PRIMARY KEY,
+ data text)
+
+CREATE TABLE IF NOT EXISTS document (
+ hash text PRIMARY KEY)
+"""
+
+
+class Documents:
+ def __init__(self) -> None:
+ self.conn = psycopg.connect(
+ f"dbname={os.environ['RAG_DB_NAME']} user={os.environ['RAG_USER']}"
+ )
+ self.__create_content_table()
+
+ def close(self):
+ self.conn.close()
+
+ def __create_content_table(self):
+ with self.conn.cursor() as cur:
+ cur.execute(TABLES)
+ self.conn.commit()
+
+ def __hash(self, chunks: List[Document]) -> str:
+ document = str.encode("".join([chunk.page_content for chunk in chunks]))
+ return hashlib.sha256(document).hexdigest()
+
+ def add_document(self, chunks: List[Document]) -> bool:
+ with self.conn.cursor() as cur:
+ hash = self.__hash(chunks)
+ cur.execute(
+ """
+ SELECT * FROM document
+ WHERE
+ hash = %s
+ """,
+ hash,
+ )
+ exist = cur.fetchone()
+ if exist is None:
+ cur.execute(
+ """
+ INSERT INTO document
+ (hash) VALUES
+ (%s)
+ """,
+ hash,
+ )
+ self.conn.commit()
+ return exist is not None
+
+
+ def add_chunk(self, data: str):
+ with self.conn.cursor() as cur:
+ cur.execute(
+ """
+ INSERT INTO chunk
+ (data) VALUES
+ (%s)
+ """,
+ data,
+ )
+ self.conn.commit()
+
+ def get_chunk(self, id: int):
+ with self.conn.cursor() as cur:
+ cur.execute(
+ """
+ SELECT * FROM chunk
+ WHERE
+ id = %s
+ """,
+ str(id),
+ )
+ chunk = cur.fetchone()
+ self.conn.commit()
+ return chunk
diff --git a/rag/db/vector.py b/rag/db/vector.py
new file mode 100644
index 0000000..f229ba7
--- /dev/null
+++ b/rag/db/vector.py
@@ -0,0 +1,23 @@
+from typing import Tuple
+import faiss
+import numpy as np
+
+# TODO: read from .env
+EMBEDDING_DIM = 1024
+
+
+# TODO: inner product distance metric?
+class VectorStore:
+ def __init__(self):
+ self.index = faiss.IndexFlatL2(EMBEDDING_DIM)
+ # TODO: load from file
+
+ def add(self, embeddings: np.ndarray):
+ # TODO: save to file
+ self.index.add(embeddings)
+
+ def search(
+ self, query: np.ndarray, neighbors: int = 4
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ score, indices = self.index.search(query, neighbors)
+ return score, indices