summaryrefslogtreecommitdiff
path: root/rag/retriever/retriever.py
blob: 885dafea530815f633cf9a89edc3e0f95e303a45 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from pathlib import Path
from typing import Optional, List
from loguru import logger as log

from io import BytesIO
from .document import DocumentDB
from .encoder import Encoder
from .parser.pdf import PDFParser
from .vector import VectorDB, Document


class Retriever:
    def __init__(self) -> None:
        self.pdf_parser = PDFParser()
        self.encoder = Encoder()
        self.doc_db = DocumentDB()
        self.vec_db = VectorDB()

    def __add_pdf_from_path(self, path: Path):
        log.debug(f"Adding pdf from {path}")
        blob = self.pdf_parser.from_path(path)
        self.__add_pdf_from_blob(blob)

    def __add_pdf_from_blob(self, blob: BytesIO, source: Optional[str] = None):
        if self.doc_db.add(blob):
            log.debug("Adding pdf to vector database...")
            document = self.pdf_parser.from_data(blob)
            chunks = self.pdf_parser.chunk(document, source)
            points = self.encoder.encode_document(chunks)
            self.vec_db.add(points)
        else:
            log.debug("Document already exists!")

    def add_pdf(
        self,
        path: Optional[Path] = None,
        blob: Optional[BytesIO] = None,
        source: Optional[str] = None,
    ):
        if path:
            self.__add_pdf_from_path(path)
        elif blob and source:
            self.__add_pdf_from_blob(blob, source)
        else:
            log.error("Invalid input!")

    def retrieve(self, query: str, limit: int = 5) -> List[Document]:
        log.debug(f"Finding documents matching query: {query}")
        query_emb = self.encoder.encode_query(query)
        return self.vec_db.search(query_emb, limit)