summaryrefslogtreecommitdiff
path: root/rag
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-06 13:15:07 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-06 13:15:07 +0200
commit052bf63a2c18b1b55013dcf6974228609cc4d76f (patch)
tree1846b4c5555ca504bfb638f72bee14249f502577 /rag
parentd116abc63e350b092c2a7f9e1bb9b54298e21b2d (diff)
Refactor pdf reader
Diffstat (limited to 'rag')
-rw-r--r--rag/db/document.py11
-rw-r--r--rag/parser/pdf.py34
-rw-r--r--rag/rag.py14
3 files changed, 29 insertions, 30 deletions
diff --git a/rag/db/document.py b/rag/db/document.py
index 763eb11..b657e55 100644
--- a/rag/db/document.py
+++ b/rag/db/document.py
@@ -1,9 +1,7 @@
import hashlib
import os
-from typing import List
import psycopg
-from langchain_core.documents.base import Document
from loguru import logger as log
TABLES = """
@@ -28,14 +26,13 @@ class DocumentDB:
cur.execute(TABLES)
self.conn.commit()
- def __hash(self, chunks: List[Document]) -> str:
+ def __hash(self, blob: bytes) -> str:
log.debug("Hashing document...")
- document = str.encode("".join([chunk.page_content for chunk in chunks]))
- return hashlib.sha256(document).hexdigest()
+ return hashlib.sha256(blob).hexdigest()
- def add(self, chunks: List[Document]) -> bool:
+ def add(self, blob: bytes) -> bool:
with self.conn.cursor() as cur:
- hash = self.__hash(chunks)
+ hash = self.__hash(blob)
cur.execute(
"""
SELECT * FROM document
diff --git a/rag/parser/pdf.py b/rag/parser/pdf.py
index 1680a47..22fc4e0 100644
--- a/rag/parser/pdf.py
+++ b/rag/parser/pdf.py
@@ -1,18 +1,34 @@
import os
from pathlib import Path
+from typing import Iterator, Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
+from langchain_core.documents import Document
+from langchain_community.document_loaders.parsers.pdf import (
+ PyPDFParser,
+)
+from rag.db.document import DocumentDB
-def parser(filepath: Path):
- content = PyPDFLoader(filepath).load()
- splitter = RecursiveCharacterTextSplitter(
- chunk_size=int(os.environ["CHUNK_SIZE"]),
- chunk_overlap=int(os.environ["CHUNK_OVERLAP"]),
- )
- chunks = splitter.split_documents(content)
- return chunks
+class PDF:
+ def __init__(self) -> None:
+ self.db = DocumentDB()
+ self.parser = PyPDFParser(password=None, extract_images=False)
+ def from_data(self, blob) -> Optional[Iterator[Document]]:
+ if self.db.add(blob):
+ yield from self.parser.parse(blob)
+ yield None
-# TODO: add parser for bytearray
+ def from_path(self, file_path: Path) -> Optional[Iterator[Document]]:
+ blob = Blob.from_path(file_path)
+ from_data(blob)
+
+ def chunk(self, content: Iterator[Document]):
+ splitter = RecursiveCharacterTextSplitter(
+ chunk_size=int(os.environ["CHUNK_SIZE"]),
+ chunk_overlap=int(os.environ["CHUNK_OVERLAP"]),
+ )
+ chunks = splitter.split_documents(content)
+ return chunks
diff --git a/rag/rag.py b/rag/rag.py
index 87b44c5..6826a80 100644
--- a/rag/rag.py
+++ b/rag/rag.py
@@ -1,15 +1,12 @@
-from pathlib import Path
from typing import List
from dotenv import load_dotenv
from loguru import logger as log
from qdrant_client.models import StrictFloat
-from rag.db.document import DocumentDB
from rag.db.vector import VectorDB
from rag.llm.encoder import Encoder
from rag.llm.generator import Generator, Prompt
-from rag.parser import pdf
class RAG:
@@ -17,19 +14,8 @@ class RAG:
load_dotenv()
self.generator = Generator()
self.encoder = Encoder()
- self.document_db = DocumentDB()
self.vector_db = VectorDB()
- def add_pdf(self, filepath: Path):
- chunks = pdf.parser(filepath)
- added = self.document_db.add(chunks)
- if added:
- log.debug(f"Adding pdf with filepath: {filepath} to vector db")
- points = self.encoder.encode_document(chunks)
- self.vector_db.add(points)
- else:
- log.debug("Document already exists!")
-
def __context(self, query_emb: List[StrictFloat], limit: int) -> str:
hits = self.vector_db.search(query_emb, limit)
log.debug(f"Got {len(hits)} hits in the vector db with limit={limit}")