diff options
Diffstat (limited to 'rag/retriever/retriever.py')
-rw-r--r-- | rag/retriever/retriever.py | 19 |
1 files changed, 16 insertions, 3 deletions
diff --git a/rag/retriever/retriever.py b/rag/retriever/retriever.py index dbfdfa2..885dafe 100644 --- a/rag/retriever/retriever.py +++ b/rag/retriever/retriever.py @@ -16,12 +16,12 @@ class Retriever: self.doc_db = DocumentDB() self.vec_db = VectorDB() - def add_pdf_from_path(self, path: Path): + def __add_pdf_from_path(self, path: Path): log.debug(f"Adding pdf from {path}") blob = self.pdf_parser.from_path(path) - self.add_pdf_from_blob(blob) + self.__add_pdf_from_blob(blob) - def add_pdf_from_blob(self, blob: BytesIO, source: Optional[str] = None): + def __add_pdf_from_blob(self, blob: BytesIO, source: Optional[str] = None): if self.doc_db.add(blob): log.debug("Adding pdf to vector database...") document = self.pdf_parser.from_data(blob) @@ -31,6 +31,19 @@ class Retriever: else: log.debug("Document already exists!") + def add_pdf( + self, + path: Optional[Path] = None, + blob: Optional[BytesIO] = None, + source: Optional[str] = None, + ): + if path: + self.__add_pdf_from_path(path) + elif blob and source: + self.__add_pdf_from_blob(blob, source) + else: + log.error("Invalid input!") + def retrieve(self, query: str, limit: int = 5) -> List[Document]: log.debug(f"Finding documents matching query: {query}") query_emb = self.encoder.encode_query(query) |