diff options
Diffstat (limited to 'rag/rag.py')
-rw-r--r-- | rag/rag.py | 34 |
1 files changed, 16 insertions, 18 deletions
@@ -5,20 +5,22 @@ from typing import List from dotenv import load_dotenv from loguru import logger as log -from qdrant_client.models import StrictFloat + try: - from rag.db.vector import VectorDB + from rag.db.vector import VectorDB, Document from rag.db.document import DocumentDB from rag.llm.encoder import Encoder - from rag.llm.generator import Generator, Prompt + from rag.llm.ollama_generator import OllamaGenerator, Prompt + from rag.llm.cohere_generator import CohereGenerator from rag.parser.pdf import PDFParser except ModuleNotFoundError: - from db.vector import VectorDB + from db.vector import VectorDB, Document from db.document import DocumentDB from llm.encoder import Encoder - from llm.generator import Generator, Prompt + from llm.ollama_generator import OllamaGenerator, Prompt + from llm.cohere_generator import CohereGenerator from parser.pdf import PDFParser @@ -34,7 +36,7 @@ class RAG: # FIXME: load this somewhere else? load_dotenv() self.pdf_parser = PDFParser() - self.generator = Generator() + self.generator = CohereGenerator() self.encoder = Encoder() self.vector_db = VectorDB() self.doc_db = DocumentDB() @@ -43,23 +45,19 @@ class RAG: blob = self.pdf_parser.from_path(path) self.add_pdf_from_blob(blob) - def add_pdf_from_blob(self, blob: BytesIO): + def add_pdf_from_blob(self, blob: BytesIO, source: str): if self.doc_db.add(blob): log.debug("Adding pdf to vector database...") - chunks = self.pdf_parser.from_data(blob) + document = self.pdf_parser.from_data(blob) + chunks = self.pdf_parser.chunk(document, source) points = self.encoder.encode_document(chunks) self.vector_db.add(points) else: log.debug("Document already exists!") - def __context(self, query_emb: List[StrictFloat], limit: int) -> str: - hits = self.vector_db.search(query_emb, limit) - log.debug(f"Got {len(hits)} hits in the vector db with limit={limit}") - return [h.payload["text"] for h in hits] - - def retrive(self, query: str, limit: int = 5) -> Response: + def search(self, query: str, limit: int = 5) -> List[Document]: query_emb = self.encoder.encode_query(query) - context = self.__context(query_emb, limit) - prompt = Prompt(query, "\n".join(context)) - answer = self.generator.generate(prompt)["response"] - return Response(query, context, answer) + return self.vector_db.search(query_emb, limit) + + def retrieve(self, prompt: Prompt): + yield from self.generator.generate(prompt) |