author     Gustaf Rydholm <gustaf.rydholm@gmail.com>   2024-04-06 13:15:07 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>   2024-04-06 13:15:07 +0200
commit     052bf63a2c18b1b55013dcf6974228609cc4d76f
tree       1846b4c5555ca504bfb638f72bee14249f502577 /rag
parent     d116abc63e350b092c2a7f9e1bb9b54298e21b2d
Refactor pdf reader
Diffstat (limited to 'rag')
-rw-r--r--   rag/db/document.py   11
-rw-r--r--   rag/parser/pdf.py    34
-rw-r--r--   rag/rag.py           14
3 files changed, 29 insertions, 30 deletions
diff --git a/rag/db/document.py b/rag/db/document.py
index 763eb11..b657e55 100644
--- a/rag/db/document.py
+++ b/rag/db/document.py
@@ -1,9 +1,7 @@
 import hashlib
 import os
-from typing import List
 
 import psycopg
-from langchain_core.documents.base import Document
 from loguru import logger as log
 
 TABLES = """
@@ -28,14 +26,13 @@ class DocumentDB:
             cur.execute(TABLES)
             self.conn.commit()
 
-    def __hash(self, chunks: List[Document]) -> str:
+    def __hash(self, blob: bytes) -> str:
         log.debug("Hashing document...")
-        document = str.encode("".join([chunk.page_content for chunk in chunks]))
-        return hashlib.sha256(document).hexdigest()
+        return hashlib.sha256(blob).hexdigest()
 
-    def add(self, chunks: List[Document]) -> bool:
+    def add(self, blob: bytes) -> bool:
         with self.conn.cursor() as cur:
-            hash = self.__hash(chunks)
+            hash = self.__hash(blob)
             cur.execute(
                 """
                 SELECT * FROM document
diff --git a/rag/parser/pdf.py b/rag/parser/pdf.py
index 1680a47..22fc4e0 100644
--- a/rag/parser/pdf.py
+++ b/rag/parser/pdf.py
@@ -1,18 +1,34 @@
 import os
 from pathlib import Path
+from typing import Iterator, Optional
 
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.document_loaders import PyPDFLoader
+from langchain_core.documents import Document
+from langchain_community.document_loaders.parsers.pdf import (
+    PyPDFParser,
+)
 
+from rag.db.document import DocumentDB
 
-def parser(filepath: Path):
-    content = PyPDFLoader(filepath).load()
-    splitter = RecursiveCharacterTextSplitter(
-        chunk_size=int(os.environ["CHUNK_SIZE"]),
-        chunk_overlap=int(os.environ["CHUNK_OVERLAP"]),
-    )
-    chunks = splitter.split_documents(content)
-    return chunks
+class PDF:
+    def __init__(self) -> None:
+        self.db = DocumentDB()
+        self.parser = PyPDFParser(password=None, extract_images=False)
 
+    def from_data(self, blob) -> Optional[Iterator[Document]]:
+        if self.db.add(blob):
+            yield from self.parser.parse(blob)
+        yield None
 
-# TODO: add parser for bytearray
+    def from_path(self, file_path: Path) -> Optional[Iterator[Document]]:
+        blob = Blob.from_path(file_path)
+        from_data(blob)
+
+    def chunk(self, content: Iterator[Document]):
+        splitter = RecursiveCharacterTextSplitter(
+            chunk_size=int(os.environ["CHUNK_SIZE"]),
+            chunk_overlap=int(os.environ["CHUNK_OVERLAP"]),
+        )
+        chunks = splitter.split_documents(content)
+        return chunks
diff --git a/rag/rag.py b/rag/rag.py
--- a/rag/rag.py
+++ b/rag/rag.py
@@ -1,15 +1,12 @@
-from pathlib import Path
 from typing import List
 
 from dotenv import load_dotenv
 from loguru import logger as log
 from qdrant_client.models import StrictFloat
 
-from rag.db.document import DocumentDB
 from rag.db.vector import VectorDB
 from rag.llm.encoder import Encoder
 from rag.llm.generator import Generator, Prompt
-from rag.parser import pdf
 
 
 class RAG:
@@ -17,19 +14,8 @@ class RAG:
         load_dotenv()
         self.generator = Generator()
         self.encoder = Encoder()
-        self.document_db = DocumentDB()
         self.vector_db = VectorDB()
 
-    def add_pdf(self, filepath: Path):
-        chunks = pdf.parser(filepath)
-        added = self.document_db.add(chunks)
-        if added:
-            log.debug(f"Adding pdf with filepath: {filepath} to vector db")
-            points = self.encoder.encode_document(chunks)
-            self.vector_db.add(points)
-        else:
-            log.debug("Document already exists!")
-
     def __context(self, query_emb: List[StrictFloat], limit: int) -> str:
         hits = self.vector_db.search(query_emb, limit)
         log.debug(f"Got {len(hits)} hits in the vector db with limit={limit}")
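
For context, a minimal sketch of the flow the new PDF class encapsulates, wired by hand: DocumentDB.add() now deduplicates on the raw bytes of a document, PyPDFParser parses a langchain Blob, and the splitter chunks the parsed pages. This is not code from the commit; the file path is a placeholder, CHUNK_SIZE and CHUNK_OVERLAP are assumed to be set in the environment, and DocumentDB is assumed to reach its Postgres instance.

import os
from pathlib import Path

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders.parsers.pdf import PyPDFParser
from langchain_core.documents.base import Blob

from rag.db.document import DocumentDB

path = Path("example.pdf")  # placeholder path, not from the commit
db = DocumentDB()           # assumes a reachable Postgres instance
parser = PyPDFParser(password=None, extract_images=False)

# add() returns True only when the document's hash has not been seen before,
# so parsing and chunking are skipped for already-indexed documents.
if db.add(path.read_bytes()):
    pages = parser.parse(Blob.from_path(path))
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=int(os.environ["CHUNK_SIZE"]),
        chunk_overlap=int(os.environ["CHUNK_OVERLAP"]),
    )
    chunks = splitter.split_documents(pages)

With the parser moved into rag/parser/pdf.py and RAG.add_pdf() removed, ingestion and deduplication now live next to the parser rather than in the RAG orchestrator.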