summary | refs | log | tree | commit | diff
path: root/rag/retriever/retriever.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/retriever/retriever.py')
-rw-r--r-- rag/retriever/retriever.py 57
1 files changed, 33 insertions, 24 deletions
diff --git a/rag/retriever/retriever.py b/rag/retriever/retriever.py
index 351cfb0..7d43941 100644
--- a/rag/retriever/retriever.py
+++ b/rag/retriever/retriever.py
@@ -1,3 +1,4 @@
+from dataclasses import dataclass
from io import BytesIO
from pathlib import Path
from typing import List, Optional
@@ -5,11 +6,25 @@ from typing import List, Optional
from loguru import logger as log
from .document import DocumentDB
-from .encoder import Encoder
+from .encoder import Encoder, Query
from .parser.pdf import PDFParser
from .vector import Document, VectorDB
+@dataclass
+class FilePath:
+ path: Path
+
+
+@dataclass
+class Blob:
+ blob: BytesIO
+ source: Optional[str] = None
+
+
+FileType = FilePath | Blob
+
+
class Retriever:
def __init__(self) -> None:
self.pdf_parser = PDFParser()
@@ -17,35 +32,29 @@ class Retriever:
self.doc_db = DocumentDB()
self.vec_db = VectorDB()
- def __add_pdf_from_path(self, path: Path):
- log.debug(f"Adding pdf from {path}")
+ def __index_pdf_from_path(self, path: Path):
+ log.debug(f"Indexing pdf from {path}")
blob = self.pdf_parser.from_path(path)
- self.__add_pdf_from_blob(blob)
+ self.__index_pdf_from_blob(blob, None)
- def __add_pdf_from_blob(self, blob: BytesIO, source: Optional[str] = None):
- if self.doc_db.add(blob):
- log.debug("Adding pdf to vector database...")
+ def __index_pdf_from_blob(self, blob: BytesIO, source: Optional[str]):
+ if self.doc_db.create(blob):
+ log.debug("Indexing pdf to vector database...")
document = self.pdf_parser.from_data(blob)
chunks = self.pdf_parser.chunk(document, source)
- points = self.encoder.encode_document(chunks)
- self.vec_db.add(points)
+ points = self.encoder.encode(chunks)
+ self.vec_db.index(points)
else:
log.debug("Document already exists!")
- def add_pdf(
- self,
- path: Optional[Path] = None,
- blob: Optional[BytesIO] = None,
- source: Optional[str] = None,
- ):
- if path:
- self.__add_pdf_from_path(path)
- elif blob and source:
- self.__add_pdf_from_blob(blob, source)
- else:
- log.error("Invalid input!")
+ def index(self, filetype: FileType):
+ match filetype:
+ case FilePath(path):
+ self.__index_pdf_from_path(path)
+ case Blob(blob, source):
+ self.__index_pdf_from_blob(blob, source)
- def retrieve(self, query: str) -> List[Document]:
- log.debug(f"Finding documents matching query: {query}")
- query_emb = self.encoder.encode_query(query)
+ def search(self, query: Query) -> List[Document]:
+ log.debug(f"Finding documents matching query: {query.query}")
+ query_emb = self.encoder.encode(query)
return self.vec_db.search(query_emb)