diff options
-rw-r--r-- | rag/db/__init__.py | 0 | ||||
-rw-r--r-- | rag/db/documents.py | 85 | ||||
-rw-r--r-- | rag/db/vector.py | 23 |
3 files changed, 108 insertions, 0 deletions
diff --git a/rag/db/__init__.py b/rag/db/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/rag/db/__init__.py diff --git a/rag/db/documents.py b/rag/db/documents.py new file mode 100644 index 0000000..bdbf5a4 --- /dev/null +++ b/rag/db/documents.py @@ -0,0 +1,85 @@ +import os +from typing import List +import hashlib +import psycopg +from langchain_core.documents.base import Document + +TABLES = """ +CREATE TABLE IF NOT EXISTS chunk ( + id serial PRIMARY KEY, + data text) + +CREATE TABLE IF NOT EXISTS document ( + hash text PRIMARY KEY) +""" + + +class Documents: + def __init__(self) -> None: + self.conn = psycopg.connect( + f"dbname={os.environ['RAG_DB_NAME']} user={os.environ['RAG_USER']}" + ) + self.__create_content_table() + + def close(self): + self.conn.close() + + def __create_content_table(self): + with self.conn.cursor() as cur: + cur.execute(TABLES) + self.conn.commit() + + def __hash(self, chunks: List[Document]) -> str: + document = str.encode("".join([chunk.page_content for chunk in chunks])) + return hashlib.sha256(document).hexdigest() + + def add_document(self, chunks: List[Document]) -> bool: + with self.conn.cursor() as cur: + hash = self.__hash(chunks) + cur.execute( + """ + SELECT * FROM document + WHERE + hash = %s + """, + hash, + ) + exist = cur.fetchone() + if exist is None: + cur.execute( + """ + INSERT INTO document + (hash) VALUES + (%s) + """, + hash, + ) + self.conn.commit() + return exist is not None + + + def add_chunk(self, data: str): + with self.conn.cursor() as cur: + cur.execute( + """ + INSERT INTO chunk + (data) VALUES + (%s) + """, + data, + ) + self.conn.commit() + + def get_chunk(self, id: int): + with self.conn.cursor() as cur: + cur.execute( + """ + SELECT * FROM chunk + WHERE + id = %s + """, + str(id), + ) + chunk = cur.fetchone() + self.conn.commit() + return chunk diff --git a/rag/db/vector.py b/rag/db/vector.py new file mode 100644 index 0000000..f229ba7 --- /dev/null +++ b/rag/db/vector.py @@ -0,0 +1,23 @@ +from typing import Tuple +import faiss +import numpy as np + +# TODO: read from .env +EMBEDDING_DIM = 1024 + + +# TODO: inner product distance metric? +class VectorStore: + def __init__(self): + self.index = faiss.IndexFlatL2(EMBEDDING_DIM) + # TODO: load from file + + def add(self, embeddings: np.ndarray): + # TODO: save to file + self.index.add(embeddings) + + def search( + self, query: np.ndarray, neighbors: int = 4 + ) -> Tuple[np.ndarray, np.ndarray]: + score, indices = self.index.search(query, neighbors) + return score, indices |