diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-09 00:14:00 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-09 00:14:00 +0200 |
commit | 91ddb3672e514fa9824609ff047d7cab0c65631a (patch) | |
tree | 009fd82618588d2960b5207128e86875f73cccdc /rag/retriever | |
parent | d487ef8b04cc7f5ac1491f0638f902fe2abe5ac5 (diff) |
Refactor
Diffstat (limited to 'rag/retriever')
-rw-r--r-- | rag/retriever/__init__.py | 0 | ||||
-rw-r--r-- | rag/retriever/document.py | 57 | ||||
-rw-r--r-- | rag/retriever/encoder.py | 43 | ||||
-rw-r--r-- | rag/retriever/parser/__init__.py | 0 | ||||
-rw-r--r-- | rag/retriever/parser/pdf.py | 34 | ||||
-rw-r--r-- | rag/retriever/retriever.py | 37 | ||||
-rw-r--r-- | rag/retriever/vector.py | 73 |
7 files changed, 244 insertions, 0 deletions
# ============================================================
# rag/retriever/__init__.py  (new file, empty package marker)
# ============================================================

# ============================================================
# rag/retriever/document.py
# ============================================================
import hashlib
import os

from langchain_community.document_loaders.blob_loaders import Blob
import psycopg
from loguru import logger as log

# Single-column table: the sha256 hex digest of every ingested document.
TABLES = """
CREATE TABLE IF NOT EXISTS document (
    hash text PRIMARY KEY)
"""


class DocumentDB:
    """Postgres-backed registry of ingested documents.

    Stores a SHA-256 digest per document so the same file is never parsed
    and embedded twice.
    """

    def __init__(self) -> None:
        # Connection parameters come from the environment; a missing
        # variable raises KeyError early rather than connecting wrongly.
        self.conn = psycopg.connect(
            f"dbname={os.environ['DOCUMENT_DB_NAME']} user={os.environ['DOCUMENT_DB_USER']}"
        )
        self.__configure()

    def close(self) -> None:
        """Close the underlying database connection."""
        self.conn.close()

    def __configure(self) -> None:
        """Create the document table on first use (idempotent)."""
        log.debug("Creating documents table if it does not exist...")
        with self.conn.cursor() as cur:
            cur.execute(TABLES)
        self.conn.commit()

    def __hash(self, blob: Blob) -> str:
        """Return the sha256 hex digest of the blob's raw bytes."""
        log.debug("Hashing document...")
        return hashlib.sha256(blob.as_bytes()).hexdigest()

    def add(self, blob: Blob) -> bool:
        """Record the document's hash if unseen.

        Returns:
            True when the document was new (hash inserted),
            False when it was already registered.
        """
        with self.conn.cursor() as cur:
            # Renamed from `hash`, which shadowed the builtin.
            doc_hash = self.__hash(blob)
            # SELECT 1: existence probe only, no need to fetch columns.
            cur.execute(
                """
                SELECT 1 FROM document
                WHERE
                    hash = %s
                """,
                (doc_hash,),
            )
            exist = cur.fetchone()
            if exist is None:
                log.debug("Inserting document hash into documents db...")
                cur.execute(
                    """
                    INSERT INTO document
                    (hash) VALUES
                    (%s)
                    """,
                    (doc_hash,),
                )
                self.conn.commit()
        return exist is None


# ============================================================
# rag/retriever/encoder.py
# ============================================================
import os
from pathlib import Path
from typing import List, Dict
from uuid import uuid4

import ollama
from langchain_core.documents import Document
from loguru import logger as log
from qdrant_client.http.models import StrictFloat

from .vector import Point


class Encoder:
    """Embeds document chunks and search queries via an ollama model."""

    def __init__(self) -> None:
        self.model = os.environ["ENCODER_MODEL"]
        # Instruction prefix for asymmetric retrieval models: it is
        # prepended to queries only, never to stored passages.
        self.query_prompt = "Represent this sentence for searching relevant passages: "

    def __encode(self, prompt: str) -> List[StrictFloat]:
        """Return the embedding vector for a single piece of text."""
        return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"])

    def __get_source(self, metadata: Dict[str, str]) -> str:
        # Keep only the file name so payloads don't leak local paths.
        source = metadata["source"]
        return Path(source).name

    def encode_document(self, chunks: List[Document]) -> List[Point]:
        """Embed every chunk into a Point carrying text + source payload."""
        log.debug("Encoding document...")
        return [
            Point(
                id=uuid4().hex,
                vector=self.__encode(chunk.page_content),
                payload={
                    "text": chunk.page_content,
                    "source": self.__get_source(chunk.metadata),
                },
            )
            for chunk in chunks
        ]

    def encode_query(self, query: str) -> List[StrictFloat]:
        """Embed a query, applying the retrieval instruction prefix."""
        log.debug(f"Encoding query: {query}")
        query = self.query_prompt + query
        return self.__encode(query)


# ============================================================
# rag/retriever/parser/__init__.py  (new file, empty package marker)
# ============================================================

# ============================================================
# rag/retriever/parser/pdf.py
# ============================================================
import os
from pathlib import Path
from typing import List, Optional

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langchain_community.document_loaders.parsers.pdf import (
    PyPDFParser,
)
from langchain_community.document_loaders.blob_loaders import Blob


class PDFParser:
    """Thin wrapper around langchain's PyPDFParser plus chunking."""

    def __init__(self) -> None:
        self.parser = PyPDFParser(password=None, extract_images=False)

    def from_data(self, blob: Blob) -> List[Document]:
        """Parse raw PDF bytes into one Document per page."""
        return self.parser.parse(blob)

    def from_path(self, path: Path) -> Blob:
        """Wrap a file on disk as a lazily-read Blob."""
        return Blob.from_path(path)

    def chunk(
        self, document: List[Document], source: Optional[str] = None
    ) -> List[Document]:
        """Split parsed pages into overlapping chunks.

        Args:
            document: parsed pages from `from_data`.
            source: optional override written into each chunk's
                ``metadata["source"]`` (e.g. an upload filename).
        """
        # Sizes are read per call so env changes take effect without
        # rebuilding the parser.
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=int(os.environ["CHUNK_SIZE"]),
            chunk_overlap=int(os.environ["CHUNK_OVERLAP"]),
        )
        chunks = splitter.split_documents(document)
        if source is not None:
            for c in chunks:
                c.metadata["source"] = source
        return chunks


# ============================================================
# rag/retriever/retriever.py
# ============================================================
from pathlib import Path
from typing import Optional, List
from loguru import logger as log

from io import BytesIO  # kept for callers importing it from this module
from langchain_community.document_loaders.blob_loaders import Blob

from .document import DocumentDB
from .encoder import Encoder
from .parser.pdf import PDFParser
from .vector import VectorDB, Document


class Retriever:
    """Facade tying together parsing, dedup, encoding and vector search."""

    def __init__(self) -> None:
        self.pdf_parser = PDFParser()
        self.encoder = Encoder()
        self.doc_db = DocumentDB()
        self.vec_db = VectorDB()

    def add_pdf_from_path(self, path: Path):
        """Ingest a PDF from disk."""
        log.debug(f"Adding pdf from {path}")
        blob = self.pdf_parser.from_path(path)
        self.add_pdf_from_blob(blob)

    # Fix: the parameter is a langchain Blob (produced by from_path and
    # consumed by DocumentDB.add / PDFParser.from_data), not a BytesIO —
    # the original annotation was wrong.
    def add_pdf_from_blob(self, blob: Blob, source: Optional[str] = None):
        """Ingest a PDF blob unless its hash is already registered."""
        if self.doc_db.add(blob):
            log.debug("Adding pdf to vector database...")
            document = self.pdf_parser.from_data(blob)
            chunks = self.pdf_parser.chunk(document, source)
            points = self.encoder.encode_document(chunks)
            self.vec_db.add(points)
        else:
            log.debug("Document already exists!")

    def retrieve(self, query: str, limit: int = 5) -> List[Document]:
        """Return up to `limit` documents semantically matching `query`."""
        log.debug(f"Finding documents matching query: {query}")
        query_emb = self.encoder.encode_query(query)
        return self.vec_db.search(query_emb, limit)


# ============================================================
# rag/retriever/vector.py
# ============================================================
import os
from dataclasses import dataclass
from typing import Dict, List

from loguru import logger as log
from qdrant_client import QdrantClient
from qdrant_client.http.models import StrictFloat
from qdrant_client.models import Distance, PointStruct, VectorParams


@dataclass
class Point:
    # id: uuid hex string; payload carries "text" and "source" keys.
    id: str
    vector: List[StrictFloat]
    payload: Dict[str, str]


@dataclass
class Document:
    # title is the payload "source" (file name); text is the chunk body.
    title: str
    text: str


class VectorDB:
    """Qdrant-backed vector store for chunk embeddings."""

    def __init__(self, score_threshold: float = 0.6):
        self.dim = int(os.environ["EMBEDDING_DIM"])
        self.collection_name = os.environ["QDRANT_COLLECTION_NAME"]
        self.client = QdrantClient(url=os.environ["QDRANT_URL"])
        # Hits scoring below this cosine similarity are dropped.
        self.score_threshold = score_threshold
        self.__configure()

    def __configure(self):
        """Create the collection on first use (idempotent)."""
        collections = list(
            map(lambda col: col.name, self.client.get_collections().collections)
        )
        if self.collection_name not in collections:
            log.debug(f"Configuring collection {self.collection_name}...")
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=self.dim, distance=Distance.COSINE),
            )
        else:
            log.debug(f"Collection {self.collection_name} already exists...")

    def add(self, points: List[Point]):
        """Upload points in parallel with retries."""
        log.debug(f"Inserting {len(points)} vectors into the vector db...")
        self.client.upload_points(
            collection_name=self.collection_name,
            points=[
                PointStruct(id=point.id, vector=point.vector, payload=point.payload)
                for point in points
            ],
            parallel=4,
            max_retries=3,
        )

    def search(self, query: List[float], limit: int = 5) -> List[Document]:
        """Return the best-scoring documents for a query embedding."""
        log.debug("Searching for vectors...")
        hits = self.client.search(
            collection_name=self.collection_name,
            query_vector=query,
            limit=limit,
            score_threshold=self.score_threshold,
        )
        log.debug(f"Got {len(hits)} hits in the vector db with limit={limit}")
        return list(
            map(
                lambda h: Document(
                    title=h.payload.get("source", ""), text=h.payload["text"]
                ),
                hits,
            )
        )