diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-09 00:14:00 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-09 00:14:00 +0200 |
commit | 91ddb3672e514fa9824609ff047d7cab0c65631a (patch) | |
tree | 009fd82618588d2960b5207128e86875f73cccdc /rag/rag.py | |
parent | d487ef8b04cc7f5ac1491f0638f902fe2abe5ac5 (diff) |
Refactor
Diffstat (limited to 'rag/rag.py')
-rw-r--r-- | rag/rag.py | 63 |
1 files changed, 26 insertions, 37 deletions
@@ -1,27 +1,22 @@ from dataclasses import dataclass from io import BytesIO from pathlib import Path -from typing import List +from typing import List, Optional, Type from dotenv import load_dotenv from loguru import logger as log - try: - from rag.db.vector import VectorDB, Document - from rag.db.document import DocumentDB - from rag.llm.encoder import Encoder - from rag.llm.ollama_generator import OllamaGenerator, Prompt - from rag.llm.cohere_generator import CohereGenerator - from rag.parser.pdf import PDFParser + from rag.retriever.vector import Document + from rag.generator.abstract import AbstractGenerator + from rag.retriever.retriever import Retriever + from rag.generator.prompt import Prompt except ModuleNotFoundError: - from db.vector import VectorDB, Document - from db.document import DocumentDB - from llm.encoder import Encoder - from llm.ollama_generator import OllamaGenerator, Prompt - from llm.cohere_generator import CohereGenerator - from parser.pdf import PDFParser + from retriever.vector import Document + from generator.abstract import AbstractGenerator + from retriever.retriever import Retriever + from generator.prompt import Prompt @dataclass @@ -35,29 +30,23 @@ class RAG: def __init__(self) -> None: # FIXME: load this somewhere else? load_dotenv() - self.pdf_parser = PDFParser() - self.generator = CohereGenerator() - self.encoder = Encoder() - self.vector_db = VectorDB() - self.doc_db = DocumentDB() - - def add_pdf_from_path(self, path: Path): - blob = self.pdf_parser.from_path(path) - self.add_pdf_from_blob(blob) - - def add_pdf_from_blob(self, blob: BytesIO, source: str): - if self.doc_db.add(blob): - log.debug("Adding pdf to vector database...") - document = self.pdf_parser.from_data(blob) - chunks = self.pdf_parser.chunk(document, source) - points = self.encoder.encode_document(chunks) - self.vector_db.add(points) + self.retriever = Retriever() + + def add_pdf( + self, + path: Optional[Path], + blob: Optional[BytesIO], + source: Optional[str] = None, + ): + if path: + self.retriever.add_pdf_from_path(path) + elif blob: + self.retriever.add_pdf_from_blob(blob, source) else: - log.debug("Document already exists!") + log.error("Both path and blob was None, no pdf added!") - def search(self, query: str, limit: int = 5) -> List[Document]: - query_emb = self.encoder.encode_query(query) - return self.vector_db.search(query_emb, limit) + def retrieve(self, query: str, limit: int = 5) -> List[Document]: + return self.retriever.retrieve(query, limit) - def retrieve(self, prompt: Prompt): - yield from self.generator.generate(prompt) + def generate(self, generator: Type[AbstractGenerator], prompt: Prompt): + yield from generator.generate(prompt) |