diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-09 00:14:00 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-09 00:14:00 +0200 |
commit | 91ddb3672e514fa9824609ff047d7cab0c65631a (patch) | |
tree | 009fd82618588d2960b5207128e86875f73cccdc /rag/db | |
parent | d487ef8b04cc7f5ac1491f0638f902fe2abe5ac5 (diff) |
Refactor
Diffstat (limited to 'rag/db')
-rw-r--r-- | rag/db/__init__.py | 0 | ||||
-rw-r--r-- | rag/db/document.py | 57 | ||||
-rw-r--r-- | rag/db/vector.py | 73 |
3 files changed, 0 insertions, 130 deletions
diff --git a/rag/db/__init__.py b/rag/db/__init__.py deleted file mode 100644 index e69de29..0000000 --- a/rag/db/__init__.py +++ /dev/null diff --git a/rag/db/document.py b/rag/db/document.py deleted file mode 100644 index 54ac451..0000000 --- a/rag/db/document.py +++ /dev/null @@ -1,57 +0,0 @@ -import hashlib -import os - -from langchain_community.document_loaders.blob_loaders import Blob -import psycopg -from loguru import logger as log - -TABLES = """ -CREATE TABLE IF NOT EXISTS document ( - hash text PRIMARY KEY) -""" - - -class DocumentDB: - def __init__(self) -> None: - self.conn = psycopg.connect( - f"dbname={os.environ['DOCUMENT_DB_NAME']} user={os.environ['DOCUMENT_DB_USER']}" - ) - self.__configure() - - def close(self): - self.conn.close() - - def __configure(self): - log.debug("Creating documents table if it does not exist...") - with self.conn.cursor() as cur: - cur.execute(TABLES) - self.conn.commit() - - def __hash(self, blob: Blob) -> str: - log.debug("Hashing document...") - return hashlib.sha256(blob.as_bytes()).hexdigest() - - def add(self, blob: Blob) -> bool: - with self.conn.cursor() as cur: - hash = self.__hash(blob) - cur.execute( - """ - SELECT * FROM document - WHERE - hash = %s - """, - (hash,), - ) - exist = cur.fetchone() - if exist is None: - log.debug("Inserting document hash into documents db...") - cur.execute( - """ - INSERT INTO document - (hash) VALUES - (%s) - """, - (hash,), - ) - self.conn.commit() - return exist is None diff --git a/rag/db/vector.py b/rag/db/vector.py deleted file mode 100644 index fd2b2c2..0000000 --- a/rag/db/vector.py +++ /dev/null @@ -1,73 +0,0 @@ -import os -from dataclasses import dataclass -from typing import Dict, List - -from loguru import logger as log -from qdrant_client import QdrantClient -from qdrant_client.http.models import StrictFloat -from qdrant_client.models import Distance, PointStruct, VectorParams - - -@dataclass -class Point: - id: str - vector: List[StrictFloat] - payload: Dict[str, str] - - -@dataclass -class Document: - title: str - text: str - - -class VectorDB: - def __init__(self, score_threshold: float = 0.6): - self.dim = int(os.environ["EMBEDDING_DIM"]) - self.collection_name = os.environ["QDRANT_COLLECTION_NAME"] - self.client = QdrantClient(url=os.environ["QDRANT_URL"]) - self.score_threshold = score_threshold - self.__configure() - - def __configure(self): - collections = list( - map(lambda col: col.name, self.client.get_collections().collections) - ) - if self.collection_name not in collections: - log.debug(f"Configuring collection {self.collection_name}...") - self.client.create_collection( - collection_name=self.collection_name, - vectors_config=VectorParams(size=self.dim, distance=Distance.COSINE), - ) - else: - log.debug(f"Collection {self.collection_name} already exists...") - - def add(self, points: List[Point]): - log.debug(f"Inserting {len(points)} vectors into the vector db...") - self.client.upload_points( - collection_name=self.collection_name, - points=[ - PointStruct(id=point.id, vector=point.vector, payload=point.payload) - for point in points - ], - parallel=4, - max_retries=3, - ) - - def search(self, query: List[float], limit: int = 5) -> List[Document]: - log.debug("Searching for vectors...") - hits = self.client.search( - collection_name=self.collection_name, - query_vector=query, - limit=limit, - score_threshold=self.score_threshold, - ) - log.debug(f"Got {len(hits)} hits in the vector db with limit={limit}") - return list( - map( - lambda h: Document( - title=h.payload.get("source", ""), text=h.payload["text"] - ), - hits, - ) - ) |