summaryrefslogtreecommitdiff
path: root/rag/db
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-09 00:14:00 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-09 00:14:00 +0200
commit91ddb3672e514fa9824609ff047d7cab0c65631a (patch)
tree009fd82618588d2960b5207128e86875f73cccdc /rag/db
parentd487ef8b04cc7f5ac1491f0638f902fe2abe5ac5 (diff)
Refactor
Diffstat (limited to 'rag/db')
-rw-r--r--rag/db/__init__.py0
-rw-r--r--rag/db/document.py57
-rw-r--r--rag/db/vector.py73
3 files changed, 0 insertions, 130 deletions
diff --git a/rag/db/__init__.py b/rag/db/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/rag/db/__init__.py
+++ /dev/null
diff --git a/rag/db/document.py b/rag/db/document.py
deleted file mode 100644
index 54ac451..0000000
--- a/rag/db/document.py
+++ /dev/null
@@ -1,57 +0,0 @@
-import hashlib
-import os
-
-from langchain_community.document_loaders.blob_loaders import Blob
-import psycopg
-from loguru import logger as log
-
-TABLES = """
-CREATE TABLE IF NOT EXISTS document (
- hash text PRIMARY KEY)
-"""
-
-
-class DocumentDB:
- def __init__(self) -> None:
- self.conn = psycopg.connect(
- f"dbname={os.environ['DOCUMENT_DB_NAME']} user={os.environ['DOCUMENT_DB_USER']}"
- )
- self.__configure()
-
- def close(self):
- self.conn.close()
-
- def __configure(self):
- log.debug("Creating documents table if it does not exist...")
- with self.conn.cursor() as cur:
- cur.execute(TABLES)
- self.conn.commit()
-
- def __hash(self, blob: Blob) -> str:
- log.debug("Hashing document...")
- return hashlib.sha256(blob.as_bytes()).hexdigest()
-
- def add(self, blob: Blob) -> bool:
- with self.conn.cursor() as cur:
- hash = self.__hash(blob)
- cur.execute(
- """
- SELECT * FROM document
- WHERE
- hash = %s
- """,
- (hash,),
- )
- exist = cur.fetchone()
- if exist is None:
- log.debug("Inserting document hash into documents db...")
- cur.execute(
- """
- INSERT INTO document
- (hash) VALUES
- (%s)
- """,
- (hash,),
- )
- self.conn.commit()
- return exist is None
diff --git a/rag/db/vector.py b/rag/db/vector.py
deleted file mode 100644
index fd2b2c2..0000000
--- a/rag/db/vector.py
+++ /dev/null
@@ -1,73 +0,0 @@
-import os
-from dataclasses import dataclass
-from typing import Dict, List
-
-from loguru import logger as log
-from qdrant_client import QdrantClient
-from qdrant_client.http.models import StrictFloat
-from qdrant_client.models import Distance, PointStruct, VectorParams
-
-
-@dataclass
-class Point:
- id: str
- vector: List[StrictFloat]
- payload: Dict[str, str]
-
-
-@dataclass
-class Document:
- title: str
- text: str
-
-
-class VectorDB:
- def __init__(self, score_threshold: float = 0.6):
- self.dim = int(os.environ["EMBEDDING_DIM"])
- self.collection_name = os.environ["QDRANT_COLLECTION_NAME"]
- self.client = QdrantClient(url=os.environ["QDRANT_URL"])
- self.score_threshold = score_threshold
- self.__configure()
-
- def __configure(self):
- collections = list(
- map(lambda col: col.name, self.client.get_collections().collections)
- )
- if self.collection_name not in collections:
- log.debug(f"Configuring collection {self.collection_name}...")
- self.client.create_collection(
- collection_name=self.collection_name,
- vectors_config=VectorParams(size=self.dim, distance=Distance.COSINE),
- )
- else:
- log.debug(f"Collection {self.collection_name} already exists...")
-
- def add(self, points: List[Point]):
- log.debug(f"Inserting {len(points)} vectors into the vector db...")
- self.client.upload_points(
- collection_name=self.collection_name,
- points=[
- PointStruct(id=point.id, vector=point.vector, payload=point.payload)
- for point in points
- ],
- parallel=4,
- max_retries=3,
- )
-
- def search(self, query: List[float], limit: int = 5) -> List[Document]:
- log.debug("Searching for vectors...")
- hits = self.client.search(
- collection_name=self.collection_name,
- query_vector=query,
- limit=limit,
- score_threshold=self.score_threshold,
- )
- log.debug(f"Got {len(hits)} hits in the vector db with limit={limit}")
- return list(
- map(
- lambda h: Document(
- title=h.payload.get("source", ""), text=h.payload["text"]
- ),
- hits,
- )
- )