diff options
Diffstat (limited to 'rag/db')
-rw-r--r-- | rag/db/documents.py | 8 | ++++++-- |
-rw-r--r-- | rag/db/vectors.py (renamed from rag/db/embeddings.py) | 30 | ++++++++++++++++++------------ |
2 files changed, 24 insertions, 14 deletions
diff --git a/rag/db/documents.py b/rag/db/documents.py index 7d088da..6f83b1f 100644 --- a/rag/db/documents.py +++ b/rag/db/documents.py @@ -4,6 +4,7 @@ from typing import List import psycopg from langchain_core.documents.base import Document +from loguru import logger as log TABLES = """ CREATE TABLE IF NOT EXISTS document ( @@ -16,21 +17,24 @@ class Documents: self.conn = psycopg.connect( f"dbname={os.environ['RAG_DB_NAME']} user={os.environ['RAG_DB_USER']}" ) - self.__create_content_table() + self.__configure() def close(self): self.conn.close() - def __create_content_table(self): + def __configure(self): + log.debug("Creating documents table if it does not exist...") with self.conn.cursor() as cur: cur.execute(TABLES) self.conn.commit() def __hash(self, chunks: List[Document]) -> str: + log.debug("Generating sha256 hash for pdf document") document = str.encode("".join([chunk.page_content for chunk in chunks])) return hashlib.sha256(document).hexdigest() def add_document(self, chunks: List[Document]) -> bool: + log.debug("Inserting document hash into documents db...") with self.conn.cursor() as cur: hash = self.__hash(chunks) cur.execute( diff --git a/rag/db/embeddings.py b/rag/db/vectors.py index 4f61dcc..9e8becb 100644 --- a/rag/db/embeddings.py +++ b/rag/db/vectors.py @@ -1,11 +1,11 @@ import os from dataclasses import dataclass from typing import Dict, List -from uuid import uuid4 from qdrant_client import QdrantClient from qdrant_client.http.models import StrictFloat -from qdrant_client.models import Distance, VectorParams, PointStruct +from qdrant_client.models import Distance, ScoredPoint, VectorParams, PointStruct +from loguru import logger as log @dataclass @@ -15,21 +15,26 @@ class Point: payload: Dict[str, str] -class Embeddings: +class Vectors: def __init__(self): self.dim = int(os.environ["EMBEDDING_DIM"]) self.collection_name = os.environ["QDRANT_COLLECTION_NAME"] self.client = QdrantClient(url=os.environ["QDRANT_URL"]) - 
self.client.delete_collection( - collection_name=self.collection_name, - ) - self.client.create_collection( - collection_name=self.collection_name, - vectors_config=VectorParams(size=self.dim, distance=Distance.COSINE), - ) + self.__configure() + + def __configure(self): + collections = list(map(lambda col: col.name, self.client.get_collections())) + if self.collection_name not in collections: + log.debug(f"Configuring collection {self.collection_name}...") + self.client.create_collection( + collection_name=self.collection_name, + vectors_config=VectorParams(size=self.dim, distance=Distance.COSINE), + ) + else: + log.debug(f"Collection {self.collection_name} already exists...") def add(self, points: List[Point]): - print(len(points)) + log.debug(f"Inserting {len(points)} vectors into the vector db...") self.client.upload_points( collection_name=self.collection_name, points=[ @@ -40,7 +45,8 @@ class Embeddings: max_retries=3, ) - def search(self, query: List[float], limit: int = 4): + def search(self, query: List[float], limit: int = 4) -> List[ScoredPoint]: + log.debug("Searching for vectors...") hits = self.client.search( collection_name=self.collection_name, query_vector=query, limit=limit ) |