summaryrefslogtreecommitdiff
path: root/rag/rag.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-09 00:14:00 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-09 00:14:00 +0200
commit91ddb3672e514fa9824609ff047d7cab0c65631a (patch)
tree009fd82618588d2960b5207128e86875f73cccdc /rag/rag.py
parentd487ef8b04cc7f5ac1491f0638f902fe2abe5ac5 (diff)
Refactor
Diffstat (limited to 'rag/rag.py')
-rw-r--r--rag/rag.py63
1 files changed, 26 insertions, 37 deletions
diff --git a/rag/rag.py b/rag/rag.py
index 93f9fd7..c95d93a 100644
--- a/rag/rag.py
+++ b/rag/rag.py
@@ -1,27 +1,22 @@
from dataclasses import dataclass
from io import BytesIO
from pathlib import Path
-from typing import List
+from typing import List, Optional, Type
from dotenv import load_dotenv
from loguru import logger as log
-
try:
- from rag.db.vector import VectorDB, Document
- from rag.db.document import DocumentDB
- from rag.llm.encoder import Encoder
- from rag.llm.ollama_generator import OllamaGenerator, Prompt
- from rag.llm.cohere_generator import CohereGenerator
- from rag.parser.pdf import PDFParser
+ from rag.retriever.vector import Document
+ from rag.generator.abstract import AbstractGenerator
+ from rag.retriever.retriever import Retriever
+ from rag.generator.prompt import Prompt
except ModuleNotFoundError:
- from db.vector import VectorDB, Document
- from db.document import DocumentDB
- from llm.encoder import Encoder
- from llm.ollama_generator import OllamaGenerator, Prompt
- from llm.cohere_generator import CohereGenerator
- from parser.pdf import PDFParser
+ from retriever.vector import Document
+ from generator.abstract import AbstractGenerator
+ from retriever.retriever import Retriever
+ from generator.prompt import Prompt
@dataclass
@@ -35,29 +30,23 @@ class RAG:
def __init__(self) -> None:
# FIXME: load this somewhere else?
load_dotenv()
- self.pdf_parser = PDFParser()
- self.generator = CohereGenerator()
- self.encoder = Encoder()
- self.vector_db = VectorDB()
- self.doc_db = DocumentDB()
-
- def add_pdf_from_path(self, path: Path):
- blob = self.pdf_parser.from_path(path)
- self.add_pdf_from_blob(blob)
-
- def add_pdf_from_blob(self, blob: BytesIO, source: str):
- if self.doc_db.add(blob):
- log.debug("Adding pdf to vector database...")
- document = self.pdf_parser.from_data(blob)
- chunks = self.pdf_parser.chunk(document, source)
- points = self.encoder.encode_document(chunks)
- self.vector_db.add(points)
+ self.retriever = Retriever()
+
+ def add_pdf(
+ self,
+ path: Optional[Path],
+ blob: Optional[BytesIO],
+ source: Optional[str] = None,
+ ):
+ if path:
+ self.retriever.add_pdf_from_path(path)
+ elif blob:
+ self.retriever.add_pdf_from_blob(blob, source)
else:
- log.debug("Document already exists!")
+ log.error("Both path and blob was None, no pdf added!")
- def search(self, query: str, limit: int = 5) -> List[Document]:
- query_emb = self.encoder.encode_query(query)
- return self.vector_db.search(query_emb, limit)
+ def retrieve(self, query: str, limit: int = 5) -> List[Document]:
+ return self.retriever.retrieve(query, limit)
- def retrieve(self, prompt: Prompt):
- yield from self.generator.generate(prompt)
+ def generate(self, generator: Type[AbstractGenerator], prompt: Prompt):
+ yield from generator.generate(prompt)