From 9f8f6c8978cfd359d5af4c001121c80853a09dc7 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sat, 6 Apr 2024 01:22:36 +0200 Subject: Update rag --- rag/rag.py | 40 ++++++++++++++++++++++++++++++---------- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/rag/rag.py b/rag/rag.py index 5b5f5ab..1988ba7 100644 --- a/rag/rag.py +++ b/rag/rag.py @@ -1,20 +1,40 @@ from pathlib import Path -from typing import List, Optional +from typing import List -from langchain_core.documents.base import Document -from llm.encoder import Encoder -from llm.generator import Generator -from parser import pdf -from db.documents import Documents -from db.vectors import Vectors +from dotenv import load_dotenv +from loguru import logger as log +from qdrant_client.models import StrictFloat + +from rag.db.document import DocumentDB +from rag.db.vector import VectorDB +from rag.llm.encoder import Encoder +from rag.llm.generator import Generator, Prompt +from rag.parser import pdf class RAG: def __init__(self) -> None: + load_dotenv() self.generator = Generator() self.encoder = Encoder() - self.docs = Documents() - self.vectors = Vectors() + self.document_db = DocumentDB() + self.vector_db = VectorDB() - def add_pdf(self, filepath: Path) -> Optional[List[Document]]: + def add_pdf(self, filepath: Path): chunks = pdf.parser(filepath) + added = self.document_db.add_document(chunks) + if added: + log.debug(f"Adding pdf with filepath: {filepath} to vector db") + points = self.encoder.encode_document(chunks) + self.vector_db.add(points) + + def __context(self, query_emb: List[StrictFloat], limit: int) -> str: + hits = self.vector_db.search(query_emb, limit) + log.debug(f"Got {len(hits)} hits in the vector db with limit={limit}") + return "\n".join(h.payload["text"] for h in hits) + + def rag(self, query: str, role: str, limit: int = 5) -> str: + query_emb = self.encoder.encode_query(query) + context = self.__context(query_emb, limit) + prompt = Prompt(query, context) + return self.generator.generate(prompt, role)["response"] -- cgit v1.2.3-70-g09d2