From 13ac875b2269756045834d7a64e7b35acb9ce0b4 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sat, 6 Apr 2024 01:21:52 +0200 Subject: Rename dbs --- rag/db/document.py | 59 +++++++++++++++++++++++++++++++++++++++++++++++++++++ rag/db/documents.py | 59 ----------------------------------------------------- rag/db/vector.py | 55 +++++++++++++++++++++++++++++++++++++++++++++++++ rag/db/vectors.py | 53 ----------------------------------------------- 4 files changed, 114 insertions(+), 112 deletions(-) create mode 100644 rag/db/document.py delete mode 100644 rag/db/documents.py create mode 100644 rag/db/vector.py delete mode 100644 rag/db/vectors.py diff --git a/rag/db/document.py b/rag/db/document.py new file mode 100644 index 0000000..8e4d208 --- /dev/null +++ b/rag/db/document.py @@ -0,0 +1,59 @@ +import hashlib +import os +from typing import List + +import psycopg +from langchain_core.documents.base import Document +from loguru import logger as log + +TABLES = """ +CREATE TABLE IF NOT EXISTS document ( + hash text PRIMARY KEY) +""" + + +class DocumentDB: + def __init__(self) -> None: + self.conn = psycopg.connect( + f"dbname={os.environ['RAG_DB_NAME']} user={os.environ['RAG_DB_USER']}" + ) + self.__configure() + + def close(self): + self.conn.close() + + def __configure(self): + log.debug("Creating documents table if it does not exist...") + with self.conn.cursor() as cur: + cur.execute(TABLES) + self.conn.commit() + + def __hash(self, chunks: List[Document]) -> str: + log.debug("Generating sha256 hash for pdf document") + document = str.encode("".join([chunk.page_content for chunk in chunks])) + return hashlib.sha256(document).hexdigest() + + def add_document(self, chunks: List[Document]) -> bool: + log.debug("Inserting document hash into documents db...") + with self.conn.cursor() as cur: + hash = self.__hash(chunks) + cur.execute( + """ + SELECT * FROM document + WHERE + hash = %s + """, + (hash,), + ) + exist = cur.fetchone() + if exist is None: + cur.execute( + """ + INSERT INTO document + (hash) VALUES + (%s) + """, + (hash,), + ) + self.conn.commit() + return exist is not None diff --git a/rag/db/documents.py b/rag/db/documents.py deleted file mode 100644 index 6f83b1f..0000000 --- a/rag/db/documents.py +++ /dev/null @@ -1,59 +0,0 @@ -import hashlib -import os -from typing import List - -import psycopg -from langchain_core.documents.base import Document -from loguru import logger as log - -TABLES = """ -CREATE TABLE IF NOT EXISTS document ( - hash text PRIMARY KEY) -""" - - -class Documents: - def __init__(self) -> None: - self.conn = psycopg.connect( - f"dbname={os.environ['RAG_DB_NAME']} user={os.environ['RAG_DB_USER']}" - ) - self.__configure() - - def close(self): - self.conn.close() - - def __configure(self): - log.debug("Creating documents table if it does not exist...") - with self.conn.cursor() as cur: - cur.execute(TABLES) - self.conn.commit() - - def __hash(self, chunks: List[Document]) -> str: - log.debug("Generating sha256 hash for pdf document") - document = str.encode("".join([chunk.page_content for chunk in chunks])) - return hashlib.sha256(document).hexdigest() - - def add_document(self, chunks: List[Document]) -> bool: - log.debug("Inserting document hash into documents db...") - with self.conn.cursor() as cur: - hash = self.__hash(chunks) - cur.execute( - """ - SELECT * FROM document - WHERE - hash = %s - """, - (hash,), - ) - exist = cur.fetchone() - if exist is None: - cur.execute( - """ - INSERT INTO document - (hash) VALUES - (%s) - """, - (hash,), - ) - self.conn.commit() - return exist is not None diff --git a/rag/db/vector.py b/rag/db/vector.py new file mode 100644 index 0000000..4aa62cc --- /dev/null +++ b/rag/db/vector.py @@ -0,0 +1,55 @@ +import os +from dataclasses import dataclass +from typing import Dict, List + +from loguru import logger as log +from qdrant_client import QdrantClient +from qdrant_client.http.models import StrictFloat +from qdrant_client.models import Distance, PointStruct, ScoredPoint, VectorParams + + +@dataclass +class Point: + id: str + vector: List[StrictFloat] + payload: Dict[str, str] + + +class VectorDB: + def __init__(self): + self.dim = int(os.environ["EMBEDDING_DIM"]) + self.collection_name = os.environ["QDRANT_COLLECTION_NAME"] + self.client = QdrantClient(url=os.environ["QDRANT_URL"]) + self.__configure() + + def __configure(self): + collections = list( + map(lambda col: col.name, self.client.get_collections().collections) + ) + if self.collection_name not in collections: + log.debug(f"Configuring collection {self.collection_name}...") + self.client.create_collection( + collection_name=self.collection_name, + vectors_config=VectorParams(size=self.dim, distance=Distance.COSINE), + ) + else: + log.debug(f"Collection {self.collection_name} already exists...") + + def add(self, points: List[Point]): + log.debug(f"Inserting {len(points)} vectors into the vector db...") + self.client.upload_points( + collection_name=self.collection_name, + points=[ + PointStruct(id=point.id, vector=point.vector, payload=point.payload) + for point in points + ], + parallel=4, + max_retries=3, + ) + + def search(self, query: List[float], limit: int = 4) -> List[ScoredPoint]: + log.debug("Searching for vectors...") + hits = self.client.search( + collection_name=self.collection_name, query_vector=query, limit=limit + ) + return hits diff --git a/rag/db/vectors.py b/rag/db/vectors.py deleted file mode 100644 index 9e8becb..0000000 --- a/rag/db/vectors.py +++ /dev/null @@ -1,53 +0,0 @@ -import os -from dataclasses import dataclass -from typing import Dict, List - -from qdrant_client import QdrantClient -from qdrant_client.http.models import StrictFloat -from qdrant_client.models import Distance, ScoredPoint, VectorParams, PointStruct -from loguru import logger as log - - -@dataclass -class Point: - id: str - vector: List[StrictFloat] - payload: Dict[str, str] - - -class Vectors: - def __init__(self): - self.dim = int(os.environ["EMBEDDING_DIM"]) - self.collection_name = os.environ["QDRANT_COLLECTION_NAME"] - self.client = QdrantClient(url=os.environ["QDRANT_URL"]) - self.__configure() - - def __configure(self): - collections = list(map(lambda col: col.name, self.client.get_collections())) - if self.collection_name not in collections: - log.debug(f"Configuring collection {self.collection_name}...") - self.client.create_collection( - collection_name=self.collection_name, - vectors_config=VectorParams(size=self.dim, distance=Distance.COSINE), - ) - else: - log.debug(f"Collection {self.collection_name} already exists...") - - def add(self, points: List[Point]): - log.debug(f"Inserting {len(points)} vectors into the vector db...") - self.client.upload_points( - collection_name=self.collection_name, - points=[ - PointStruct(id=point.id, vector=point.vector, payload=point.payload) - for point in points - ], - parallel=4, - max_retries=3, - ) - - def search(self, query: List[float], limit: int = 4) -> List[ScoredPoint]: - log.debug("Searching for vectors...") - hits = self.client.search( - collection_name=self.collection_name, query_vector=query, limit=limit - ) - return hits -- cgit v1.2.3-70-g09d2