summaryrefslogtreecommitdiff
path: root/rag/retriever
diff options
context:
space:
mode:
Diffstat (limited to 'rag/retriever')
-rw-r--r--rag/retriever/retriever.py19
1 files changed, 16 insertions, 3 deletions
diff --git a/rag/retriever/retriever.py b/rag/retriever/retriever.py
index dbfdfa2..885dafe 100644
--- a/rag/retriever/retriever.py
+++ b/rag/retriever/retriever.py
@@ -16,12 +16,12 @@ class Retriever:
self.doc_db = DocumentDB()
self.vec_db = VectorDB()
- def add_pdf_from_path(self, path: Path):
+ def __add_pdf_from_path(self, path: Path):
log.debug(f"Adding pdf from {path}")
blob = self.pdf_parser.from_path(path)
- self.add_pdf_from_blob(blob)
+ self.__add_pdf_from_blob(blob)
- def add_pdf_from_blob(self, blob: BytesIO, source: Optional[str] = None):
+ def __add_pdf_from_blob(self, blob: BytesIO, source: Optional[str] = None):
if self.doc_db.add(blob):
log.debug("Adding pdf to vector database...")
document = self.pdf_parser.from_data(blob)
@@ -31,6 +31,19 @@ class Retriever:
else:
log.debug("Document already exists!")
+ def add_pdf(
+ self,
+ path: Optional[Path] = None,
+ blob: Optional[BytesIO] = None,
+ source: Optional[str] = None,
+ ):
+ if path:
+ self.__add_pdf_from_path(path)
+ elif blob and source:
+ self.__add_pdf_from_blob(blob, source)
+ else:
+ log.error("Invalid input!")
+
def retrieve(self, query: str, limit: int = 5) -> List[Document]:
log.debug(f"Finding documents matching query: {query}")
query_emb = self.encoder.encode_query(query)