diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-08 00:23:52 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-08 00:23:52 +0200 |
commit | 8211705debf9d1335223c606275f46c43c78d8a2 (patch) | |
tree | f09f902c7d31b2035813c42cbb4a47e720fa363b /rag/rag.py | |
parent | 95f47c4900a96d91daaef93bf87094ed3d4da43c (diff) |
Updates
Diffstat (limited to 'rag/rag.py')
-rw-r--r-- | rag/rag.py | 52 |
1 file changed, 37 insertions, 15 deletions
@@ -1,3 +1,5 @@ +from dataclasses import dataclass +from io import BytesIO from pathlib import Path from typing import List @@ -5,27 +7,46 @@ from dotenv import load_dotenv from loguru import logger as log from qdrant_client.models import StrictFloat -from rag.db.document import DocumentDB -from rag.db.vector import VectorDB -from rag.llm.encoder import Encoder -from rag.llm.generator import Generator, Prompt -from rag.parser import pdf + +try: + from rag.db.vector import VectorDB + from rag.db.document import DocumentDB + from rag.llm.encoder import Encoder + from rag.llm.generator import Generator, Prompt + from rag.parser.pdf import PDFParser +except ModuleNotFoundError: + from db.vector import VectorDB + from db.document import DocumentDB + from llm.encoder import Encoder + from llm.generator import Generator, Prompt + from parser.pdf import PDFParser + + +@dataclass +class Response: + query: str + context: List[str] + answer: str class RAG: def __init__(self) -> None: # FIXME: load this somewhere else? load_dotenv() + self.pdf_parser = PDFParser() self.generator = Generator() self.encoder = Encoder() self.vector_db = VectorDB() + self.doc_db = DocumentDB() + + def add_pdf_from_path(self, path: Path): + blob = self.pdf_parser.from_path(path) + self.add_pdf_from_blob(blob) - # FIXME: refactor this, add vector? 
- def add_pdf(self, filepath: Path): - chunks = pdf.parser(filepath) - added = self.document_db.add(chunks) - if added: - log.debug(f"Adding pdf with filepath: {filepath} to vector db") + def add_pdf_from_blob(self, blob: BytesIO): + if self.doc_db.add(blob): + log.debug("Adding pdf to vector database...") + chunks = self.pdf_parser.from_data(blob) points = self.encoder.encode_document(chunks) self.vector_db.add(points) else: @@ -34,10 +55,11 @@ class RAG: def __context(self, query_emb: List[StrictFloat], limit: int) -> str: hits = self.vector_db.search(query_emb, limit) log.debug(f"Got {len(hits)} hits in the vector db with limit={limit}") - return "\n".join(h.payload["text"] for h in hits) + return [h.payload["text"] for h in hits] - def rag(self, query: str, role: str, limit: int = 5) -> str: + def retrive(self, query: str, limit: int = 5) -> Response: query_emb = self.encoder.encode_query(query) context = self.__context(query_emb, limit) - prompt = Prompt(query, context) - return self.generator.generate(prompt, role)["response"] + prompt = Prompt(query, "\n".join(context)) + answer = self.generator.generate(prompt)["response"] + return Response(query, context, answer) |