summaryrefslogtreecommitdiff
path: root/rag/retriever/retriever.py
blob: 7d43941bee0edf8c67fb0b28459d9adcd589ba99 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from dataclasses import dataclass
from io import BytesIO
from pathlib import Path
from typing import List, Optional

from loguru import logger as log

from .document import DocumentDB
from .encoder import Encoder, Query
from .parser.pdf import PDFParser
from .vector import Document, VectorDB


@dataclass
class FilePath:
    """A document to index, identified by its location on the filesystem."""

    # Filesystem location of the document.
    path: Path


@dataclass
class Blob:
    """A document to index, supplied as in-memory bytes.

    ``source`` is an optional provenance label; when the blob is indexed it is
    passed along to the chunker together with the parsed document.
    """

    # Raw document contents.
    blob: BytesIO
    # Optional label describing where the bytes came from (None when unknown).
    source: Optional[str] = None


# Inputs accepted by ``Retriever.index``: a document on disk (``FilePath``) or
# an in-memory document (``Blob``). Built with the runtime ``|`` union
# operator on classes, which requires Python 3.10+.
FileType = FilePath | Blob


class Retriever:
    def __init__(self) -> None:
        self.pdf_parser = PDFParser()
        self.encoder = Encoder()
        self.doc_db = DocumentDB()
        self.vec_db = VectorDB()

    def __index_pdf_from_path(self, path: Path):
        log.debug(f"Indexing pdf from {path}")
        blob = self.pdf_parser.from_path(path)
        self.__index_pdf_from_blob(blob, None)

    def __index_pdf_from_blob(self, blob: BytesIO, source: Optional[str]):
        if self.doc_db.create(blob):
            log.debug("Indexing pdf to vector database...")
            document = self.pdf_parser.from_data(blob)
            chunks = self.pdf_parser.chunk(document, source)
            points = self.encoder.encode(chunks)
            self.vec_db.index(points)
        else:
            log.debug("Document already exists!")

    def index(self, filetype: FileType):
        match filetype:
            case FilePath(path):
                self.__index_pdf_from_path(path)
            case Blob(blob, source):
                self.__index_pdf_from_blob(blob, source)

    def search(self, query: Query) -> List[Document]:
        log.debug(f"Finding documents matching query: {query.query}")
        query_emb = self.encoder.encode(query)
        return self.vec_db.search(query_emb)