author    Gustaf Rydholm <gustaf.rydholm@gmail.com>  2024-04-09 00:14:00 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com>  2024-04-09 00:14:00 +0200
commit    91ddb3672e514fa9824609ff047d7cab0c65631a (patch)
tree      009fd82618588d2960b5207128e86875f73cccdc /rag/retriever
parent    d487ef8b04cc7f5ac1491f0638f902fe2abe5ac5 (diff)
Refactor
Diffstat (limited to 'rag/retriever')
-rw-r--r--  rag/retriever/__init__.py          0
-rw-r--r--  rag/retriever/document.py         57
-rw-r--r--  rag/retriever/encoder.py          43
-rw-r--r--  rag/retriever/parser/__init__.py   0
-rw-r--r--  rag/retriever/parser/pdf.py       34
-rw-r--r--  rag/retriever/retriever.py        37
-rw-r--r--  rag/retriever/vector.py           73
7 files changed, 244 insertions, 0 deletions
diff --git a/rag/retriever/__init__.py b/rag/retriever/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/rag/retriever/__init__.py
diff --git a/rag/retriever/document.py b/rag/retriever/document.py
new file mode 100644
index 0000000..54ac451
--- /dev/null
+++ b/rag/retriever/document.py
@@ -0,0 +1,57 @@
+import hashlib
+import os
+
+import psycopg
+from langchain_community.document_loaders.blob_loaders import Blob
+from loguru import logger as log
+
+TABLES = """
+CREATE TABLE IF NOT EXISTS document (
+    hash text PRIMARY KEY)
+"""
+
+
+class DocumentDB:
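+    """Keeps a table of SHA-256 hashes of ingested documents so the same file is only indexed once."""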
+    def __init__(self) -> None:
+        self.conn = psycopg.connect(
+            f"dbname={os.environ['DOCUMENT_DB_NAME']} user={os.environ['DOCUMENT_DB_USER']}"
+        )
+        self.__configure()
+
+    def close(self):
+        self.conn.close()
+
+    def __configure(self):
+        log.debug("Creating documents table if it does not exist...")
+        with self.conn.cursor() as cur:
+            cur.execute(TABLES)
+        self.conn.commit()
+
+    def __hash(self, blob: Blob) -> str:
+        log.debug("Hashing document...")
+        return hashlib.sha256(blob.as_bytes()).hexdigest()
+
+    def add(self, blob: Blob) -> bool:
+        with self.conn.cursor() as cur:
+            doc_hash = self.__hash(blob)
+            cur.execute(
+                """
+                SELECT * FROM document
+                WHERE
+                    hash = %s
+                """,
+                (doc_hash,),
+            )
+            exist = cur.fetchone()
+            if exist is None:
+                log.debug("Inserting document hash into documents db...")
+                cur.execute(
+                    """
+                    INSERT INTO document
+                    (hash) VALUES
+                    (%s)
+                    """,
+                    (doc_hash,),
+                )
+        self.conn.commit()
+        return exist is None
diff --git a/rag/retriever/encoder.py b/rag/retriever/encoder.py
new file mode 100644
index 0000000..753157f
--- /dev/null
+++ b/rag/retriever/encoder.py
@@ -0,0 +1,43 @@
+import os
+from pathlib import Path
+from typing import List, Dict
+from uuid import uuid4
+
+import ollama
+from langchain_core.documents import Document
+from loguru import logger as log
+from qdrant_client.http.models import StrictFloat
+
+from .vector import Point
+
+
+class Encoder:
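+    """Embeds document chunks and queries with a local Ollama embedding model."""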
+    def __init__(self) -> None:
+        self.model = os.environ["ENCODER_MODEL"]
+        # Instruction prefix used by BGE-style retrieval models when embedding queries.
+        self.query_prompt = "Represent this sentence for searching relevant passages: "
+
+    def __encode(self, prompt: str) -> List[StrictFloat]:
+        return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"])
+
+    def __get_source(self, metadata: Dict[str, str]) -> str:
+        source = metadata["source"]
+        return Path(source).name
+
+    def encode_document(self, chunks: List[Document]) -> List[Point]:
+        log.debug("Encoding document...")
+        return [
+            Point(
+                id=uuid4().hex,
+                vector=self.__encode(chunk.page_content),
+                payload={
+                    "text": chunk.page_content,
+                    "source": self.__get_source(chunk.metadata),
+                },
+            )
+            for chunk in chunks
+        ]
+
+    def encode_query(self, query: str) -> List[StrictFloat]:
+        log.debug(f"Encoding query: {query}")
+        query = self.query_prompt + query
+        return self.__encode(query)
diff --git a/rag/retriever/parser/__init__.py b/rag/retriever/parser/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/rag/retriever/parser/__init__.py
diff --git a/rag/retriever/parser/pdf.py b/rag/retriever/parser/pdf.py
new file mode 100644
index 0000000..410f027
--- /dev/null
+++ b/rag/retriever/parser/pdf.py
@@ -0,0 +1,34 @@
+import os
+from pathlib import Path
+from typing import List, Optional
+
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_core.documents import Document
+from langchain_community.document_loaders.parsers.pdf import (
+ PyPDFParser,
+)
+from langchain_community.document_loaders.blob_loaders import Blob
+
+
+class PDFParser:
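+    """Parses pdf blobs into LangChain documents and splits them into overlapping chunks."""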
+    def __init__(self) -> None:
+        self.parser = PyPDFParser(password=None, extract_images=False)
+
+    def from_data(self, blob: Blob) -> List[Document]:
+        return self.parser.parse(blob)
+
+    def from_path(self, path: Path) -> Blob:
+        return Blob.from_path(path)
+
+    def chunk(
+        self, document: List[Document], source: Optional[str] = None
+    ) -> List[Document]:
+        # Chunk size and overlap are given in characters via the environment.
+        splitter = RecursiveCharacterTextSplitter(
+            chunk_size=int(os.environ["CHUNK_SIZE"]),
+            chunk_overlap=int(os.environ["CHUNK_OVERLAP"]),
+        )
+        chunks = splitter.split_documents(document)
+        if source is not None:
+            for c in chunks:
+                c.metadata["source"] = source
+        return chunks
diff --git a/rag/retriever/retriever.py b/rag/retriever/retriever.py
new file mode 100644
index 0000000..dbfdfa2
--- /dev/null
+++ b/rag/retriever/retriever.py
@@ -0,0 +1,37 @@
+from pathlib import Path
+from typing import List, Optional
+
+from langchain_community.document_loaders.blob_loaders import Blob
+from loguru import logger as log
+
+from .document import DocumentDB
+from .encoder import Encoder
+from .parser.pdf import PDFParser
+from .vector import VectorDB, Document
+
+
+class Retriever:
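+    """Ties together parsing, deduplication, embedding, and vector storage."""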
+    def __init__(self) -> None:
+        self.pdf_parser = PDFParser()
+        self.encoder = Encoder()
+        self.doc_db = DocumentDB()
+        self.vec_db = VectorDB()
+
+    def add_pdf_from_path(self, path: Path):
+        log.debug(f"Adding pdf from {path}")
+        blob = self.pdf_parser.from_path(path)
+        self.add_pdf_from_blob(blob)
+
+    def add_pdf_from_blob(self, blob: Blob, source: Optional[str] = None):
+        # Only parse, embed, and index the pdf if its hash has not been seen before.
+        if self.doc_db.add(blob):
+            log.debug("Adding pdf to vector database...")
+            document = self.pdf_parser.from_data(blob)
+            chunks = self.pdf_parser.chunk(document, source)
+            points = self.encoder.encode_document(chunks)
+            self.vec_db.add(points)
+        else:
+            log.debug("Document already exists!")
+
+    def retrieve(self, query: str, limit: int = 5) -> List[Document]:
+        log.debug(f"Finding documents matching query: {query}")
+        query_emb = self.encoder.encode_query(query)
+        return self.vec_db.search(query_emb, limit)
diff --git a/rag/retriever/vector.py b/rag/retriever/vector.py
new file mode 100644
index 0000000..fd2b2c2
--- /dev/null
+++ b/rag/retriever/vector.py
@@ -0,0 +1,73 @@
+import os
+from dataclasses import dataclass
+from typing import Dict, List
+
+from loguru import logger as log
+from qdrant_client import QdrantClient
+from qdrant_client.http.models import StrictFloat
+from qdrant_client.models import Distance, PointStruct, VectorParams
+
+
+@dataclass
+class Point:
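+    """A single embedded chunk ready for upload to the vector db."""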
+    id: str
+    vector: List[StrictFloat]
+    payload: Dict[str, str]
+
+
+@dataclass
+class Document:
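+    """A retrieved chunk, with the source file name as its title."""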
+    title: str
+    text: str
+
+
+class VectorDB:
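+    """Stores and searches chunk embeddings in a Qdrant collection using cosine similarity."""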
+    def __init__(self, score_threshold: float = 0.6):
+        self.dim = int(os.environ["EMBEDDING_DIM"])
+        self.collection_name = os.environ["QDRANT_COLLECTION_NAME"]
+        self.client = QdrantClient(url=os.environ["QDRANT_URL"])
+        # Hits scoring below this cosine similarity are dropped from search results.
+        self.score_threshold = score_threshold
+        self.__configure()
+
+    def __configure(self):
+        collections = list(
+            map(lambda col: col.name, self.client.get_collections().collections)
+        )
+        if self.collection_name not in collections:
+            log.debug(f"Configuring collection {self.collection_name}...")
+            self.client.create_collection(
+                collection_name=self.collection_name,
+                vectors_config=VectorParams(size=self.dim, distance=Distance.COSINE),
+            )
+        else:
+            log.debug(f"Collection {self.collection_name} already exists...")
+
+    def add(self, points: List[Point]):
+        log.debug(f"Inserting {len(points)} vectors into the vector db...")
+        self.client.upload_points(
+            collection_name=self.collection_name,
+            points=[
+                PointStruct(id=point.id, vector=point.vector, payload=point.payload)
+                for point in points
+            ],
+            parallel=4,
+            max_retries=3,
+        )
+
+    def search(self, query: List[float], limit: int = 5) -> List[Document]:
+        log.debug("Searching for vectors...")
+        hits = self.client.search(
+            collection_name=self.collection_name,
+            query_vector=query,
+            limit=limit,
+            score_threshold=self.score_threshold,
+        )
+        log.debug(f"Got {len(hits)} hits in the vector db with limit={limit}")
+        return list(
+            map(
+                lambda h: Document(
+                    title=h.payload.get("source", ""), text=h.payload["text"]
+                ),
+                hits,
+            )
+        )
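
A minimal usage sketch of the refactored retriever, assuming Postgres, Qdrant, and Ollama are reachable and the environment variables referenced above (DOCUMENT_DB_NAME, DOCUMENT_DB_USER, ENCODER_MODEL, EMBEDDING_DIM, QDRANT_COLLECTION_NAME, QDRANT_URL, CHUNK_SIZE, CHUNK_OVERLAP) are set; the pdf path and query below are hypothetical:

    from pathlib import Path

    from rag.retriever.retriever import Retriever

    retriever = Retriever()

    # Index a pdf; re-adding the same file is a no-op because DocumentDB
    # has already stored its hash.
    retriever.add_pdf_from_path(Path("papers/attention.pdf"))  # hypothetical path

    # Fetch the five highest-scoring chunks for a question.
    for doc in retriever.retrieve("What is scaled dot-product attention?", limit=5):
        print(doc.title, doc.text[:80])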