summary refs log tree commit diff
path: root/rag
diff options
context:
space:
mode:
Diffstat (limited to 'rag')
-rw-r--r-- rag/db/documents.py | 8
-rw-r--r-- rag/db/vectors.py (renamed from rag/db/embeddings.py) | 30
2 files changed, 24 insertions, 14 deletions
diff --git a/rag/db/documents.py b/rag/db/documents.py
index 7d088da..6f83b1f 100644
--- a/rag/db/documents.py
+++ b/rag/db/documents.py
@@ -4,6 +4,7 @@ from typing import List
import psycopg
from langchain_core.documents.base import Document
+from loguru import logger as log
TABLES = """
CREATE TABLE IF NOT EXISTS document (
@@ -16,21 +17,24 @@ class Documents:
self.conn = psycopg.connect(
f"dbname={os.environ['RAG_DB_NAME']} user={os.environ['RAG_DB_USER']}"
)
- self.__create_content_table()
+ self.__configure()
def close(self):
self.conn.close()
- def __create_content_table(self):
+ def __configure(self):
+ log.debug("Creating documents table if it does not exist...")
with self.conn.cursor() as cur:
cur.execute(TABLES)
self.conn.commit()
def __hash(self, chunks: List[Document]) -> str:
+ log.debug("Generating sha256 hash for pdf document")
document = str.encode("".join([chunk.page_content for chunk in chunks]))
return hashlib.sha256(document).hexdigest()
def add_document(self, chunks: List[Document]) -> bool:
+ log.debug("Inserting document hash into documents db...")
with self.conn.cursor() as cur:
hash = self.__hash(chunks)
cur.execute(
diff --git a/rag/db/embeddings.py b/rag/db/vectors.py
index 4f61dcc..9e8becb 100644
--- a/rag/db/embeddings.py
+++ b/rag/db/vectors.py
@@ -1,11 +1,11 @@
import os
from dataclasses import dataclass
from typing import Dict, List
-from uuid import uuid4
from qdrant_client import QdrantClient
from qdrant_client.http.models import StrictFloat
-from qdrant_client.models import Distance, VectorParams, PointStruct
+from qdrant_client.models import Distance, ScoredPoint, VectorParams, PointStruct
+from loguru import logger as log
@dataclass
@@ -15,21 +15,26 @@ class Point:
payload: Dict[str, str]
-class Embeddings:
+class Vectors:
def __init__(self):
self.dim = int(os.environ["EMBEDDING_DIM"])
self.collection_name = os.environ["QDRANT_COLLECTION_NAME"]
self.client = QdrantClient(url=os.environ["QDRANT_URL"])
- self.client.delete_collection(
- collection_name=self.collection_name,
- )
- self.client.create_collection(
- collection_name=self.collection_name,
- vectors_config=VectorParams(size=self.dim, distance=Distance.COSINE),
- )
+ self.__configure()
+
+ def __configure(self):
+ collections = list(map(lambda col: col.name, self.client.get_collections()))
+ if self.collection_name not in collections:
+ log.debug(f"Configuring collection {self.collection_name}...")
+ self.client.create_collection(
+ collection_name=self.collection_name,
+ vectors_config=VectorParams(size=self.dim, distance=Distance.COSINE),
+ )
+ else:
+ log.debug(f"Collection {self.collection_name} already exists...")
def add(self, points: List[Point]):
- print(len(points))
+ log.debug(f"Inserting {len(points)} vectors into the vector db...")
self.client.upload_points(
collection_name=self.collection_name,
points=[
@@ -40,7 +45,8 @@ class Embeddings:
max_retries=3,
)
- def search(self, query: List[float], limit: int = 4):
+ def search(self, query: List[float], limit: int = 4) -> List[ScoredPoint]:
+ log.debug("Searching for vectors...")
hits = self.client.search(
collection_name=self.collection_name, query_vector=query, limit=limit
)