diff options
Diffstat (limited to 'rag/retriever/retriever.py')
-rw-r--r-- | rag/retriever/retriever.py | 57 |
1 files changed, 33 insertions, 24 deletions
diff --git a/rag/retriever/retriever.py b/rag/retriever/retriever.py index 351cfb0..7d43941 100644 --- a/rag/retriever/retriever.py +++ b/rag/retriever/retriever.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from io import BytesIO from pathlib import Path from typing import List, Optional @@ -5,11 +6,25 @@ from typing import List, Optional from loguru import logger as log from .document import DocumentDB -from .encoder import Encoder +from .encoder import Encoder, Query from .parser.pdf import PDFParser from .vector import Document, VectorDB +@dataclass +class FilePath: + path: Path + + +@dataclass +class Blob: + blob: BytesIO + source: Optional[str] = None + + +FileType = FilePath | Blob + + class Retriever: def __init__(self) -> None: self.pdf_parser = PDFParser() @@ -17,35 +32,29 @@ class Retriever: self.doc_db = DocumentDB() self.vec_db = VectorDB() - def __add_pdf_from_path(self, path: Path): - log.debug(f"Adding pdf from {path}") + def __index_pdf_from_path(self, path: Path): + log.debug(f"Indexing pdf from {path}") blob = self.pdf_parser.from_path(path) - self.__add_pdf_from_blob(blob) + self.__index_pdf_from_blob(blob, None) - def __add_pdf_from_blob(self, blob: BytesIO, source: Optional[str] = None): - if self.doc_db.add(blob): - log.debug("Adding pdf to vector database...") + def __index_pdf_from_blob(self, blob: BytesIO, source: Optional[str]): + if self.doc_db.create(blob): + log.debug("Indexing pdf to vector database...") document = self.pdf_parser.from_data(blob) chunks = self.pdf_parser.chunk(document, source) - points = self.encoder.encode_document(chunks) - self.vec_db.add(points) + points = self.encoder.encode(chunks) + self.vec_db.index(points) else: log.debug("Document already exists!") - def add_pdf( - self, - path: Optional[Path] = None, - blob: Optional[BytesIO] = None, - source: Optional[str] = None, - ): - if path: - self.__add_pdf_from_path(path) - elif blob and source: - self.__add_pdf_from_blob(blob, source) - else: - log.error("Invalid input!") + def index(self, filetype: FileType): + match filetype: + case FilePath(path): + self.__index_pdf_from_path(path) + case Blob(blob, source): + self.__index_pdf_from_blob(blob, source) - def retrieve(self, query: str) -> List[Document]: - log.debug(f"Finding documents matching query: {query}") - query_emb = self.encoder.encode_query(query) + def search(self, query: Query) -> List[Document]: + log.debug(f"Finding documents matching query: {query.query}") + query_emb = self.encoder.encode(query) return self.vec_db.search(query_emb) |