summary refs log tree commit diff
path: root/rag/rag.py
diff options
context:
space:
mode:
author Gustaf Rydholm <gustaf.rydholm@gmail.com> 2024-04-08 00:23:52 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com> 2024-04-08 00:23:52 +0200
commit 8211705debf9d1335223c606275f46c43c78d8a2 (patch)
tree f09f902c7d31b2035813c42cbb4a47e720fa363b /rag/rag.py
parent 95f47c4900a96d91daaef93bf87094ed3d4da43c (diff)
Updates
Diffstat (limited to 'rag/rag.py')
-rw-r--r-- rag/rag.py 52
1 files changed, 37 insertions, 15 deletions
diff --git a/rag/rag.py b/rag/rag.py
index 488e30a..cd4537e 100644
--- a/rag/rag.py
+++ b/rag/rag.py
@@ -1,3 +1,5 @@
+from dataclasses import dataclass
+from io import BytesIO
from pathlib import Path
from typing import List
@@ -5,27 +7,46 @@ from dotenv import load_dotenv
from loguru import logger as log
from qdrant_client.models import StrictFloat
-from rag.db.document import DocumentDB
-from rag.db.vector import VectorDB
-from rag.llm.encoder import Encoder
-from rag.llm.generator import Generator, Prompt
-from rag.parser import pdf
+
+try:
+ from rag.db.vector import VectorDB
+ from rag.db.document import DocumentDB
+ from rag.llm.encoder import Encoder
+ from rag.llm.generator import Generator, Prompt
+ from rag.parser.pdf import PDFParser
+except ModuleNotFoundError:
+ from db.vector import VectorDB
+ from db.document import DocumentDB
+ from llm.encoder import Encoder
+ from llm.generator import Generator, Prompt
+ from parser.pdf import PDFParser
+
+
+@dataclass
+class Response:
+ query: str
+ context: List[str]
+ answer: str
class RAG:
def __init__(self) -> None:
# FIXME: load this somewhere else?
load_dotenv()
+ self.pdf_parser = PDFParser()
self.generator = Generator()
self.encoder = Encoder()
self.vector_db = VectorDB()
+ self.doc_db = DocumentDB()
+
+ def add_pdf_from_path(self, path: Path):
+ blob = self.pdf_parser.from_path(path)
+ self.add_pdf_from_blob(blob)
- # FIXME: refactor this, add vector?
- def add_pdf(self, filepath: Path):
- chunks = pdf.parser(filepath)
- added = self.document_db.add(chunks)
- if added:
- log.debug(f"Adding pdf with filepath: {filepath} to vector db")
+ def add_pdf_from_blob(self, blob: BytesIO):
+ if self.doc_db.add(blob):
+ log.debug("Adding pdf to vector database...")
+ chunks = self.pdf_parser.from_data(blob)
points = self.encoder.encode_document(chunks)
self.vector_db.add(points)
else:
@@ -34,10 +55,11 @@ class RAG:
def __context(self, query_emb: List[StrictFloat], limit: int) -> str:
hits = self.vector_db.search(query_emb, limit)
log.debug(f"Got {len(hits)} hits in the vector db with limit={limit}")
- return "\n".join(h.payload["text"] for h in hits)
+ return [h.payload["text"] for h in hits]
- def rag(self, query: str, role: str, limit: int = 5) -> str:
+ def retrive(self, query: str, limit: int = 5) -> Response:
query_emb = self.encoder.encode_query(query)
context = self.__context(query_emb, limit)
- prompt = Prompt(query, context)
- return self.generator.generate(prompt, role)["response"]
+ prompt = Prompt(query, "\n".join(context))
+ answer = self.generator.generate(prompt)["response"]
+ return Response(query, context, answer)