From 8211705debf9d1335223c606275f46c43c78d8a2 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Mon, 8 Apr 2024 00:23:52 +0200
Subject: Updates

---
 rag/cli.py           | 28 ++++++++++++++++++++++++++++
 rag/db/document.py   |  7 ++++---
 rag/db/vector.py     |  5 ++++-
 rag/llm/encoder.py   | 11 +++++++----
 rag/llm/generator.py |  7 +++----
 rag/main.py          |  0
 rag/parser/pdf.py    | 18 +++++++-----------
 rag/rag.py           | 52 +++++++++++++++++++++++++++++++++++++---------------
 rag/ui.py            | 39 +++++++++++++++++++++++++++++++++++----
 9 files changed, 125 insertions(+), 42 deletions(-)
 create mode 100644 rag/cli.py
 delete mode 100644 rag/main.py

diff --git a/rag/cli.py b/rag/cli.py
new file mode 100644
index 0000000..5ea1a47
--- /dev/null
+++ b/rag/cli.py
@@ -0,0 +1,28 @@
+from pathlib import Path
+
+
+try:
+    from rag.rag import RAG
+except ModuleNotFoundError:
+    from rag import RAG
+
+if __name__ == "__main__":
+    rag = RAG()
+
+    while True:
+        print("Retrieval Augmented Generation")
+        choice = input("1. add pdf from path\n2. Enter a query\n")
+        match choice:
+            case "1":
+                path = input("Enter the path to the pdf: ")
+                path = Path(path)
+                rag.add_pdf_from_path(path)
+            case "2":
+                query = input("Enter your query: ")
+                if query:
+                    result = rag.retrieve(query)
+                    print("Answer: \n")
+                    print(result.answer)
+            case _:
+                print("Invalid option!")
+
diff --git a/rag/db/document.py b/rag/db/document.py
index 528a399..54ac451 100644
--- a/rag/db/document.py
+++ b/rag/db/document.py
@@ -1,6 +1,7 @@
 import hashlib
 import os
 
+from langchain_community.document_loaders.blob_loaders import Blob
 import psycopg
 from loguru import logger as log
 
@@ -26,11 +27,11 @@ class DocumentDB:
             cur.execute(TABLES)
             self.conn.commit()
 
-    def __hash(self, blob: bytes) -> str:
+    def __hash(self, blob: Blob) -> str:
         log.debug("Hashing document...")
-        return hashlib.sha256(blob).hexdigest()
+        return hashlib.sha256(blob.as_bytes()).hexdigest()
 
-    def add(self, blob: bytes) -> bool:
+    def add(self, blob: Blob) -> bool:
         with self.conn.cursor() as cur:
             hash = self.__hash(blob)
             cur.execute(
diff --git a/rag/db/vector.py b/rag/db/vector.py
index 4aa62cc..bbbbf32 100644
--- a/rag/db/vector.py
+++ b/rag/db/vector.py
@@ -50,6 +50,9 @@ class VectorDB:
     def search(self, query: List[float], limit: int = 4) -> List[ScoredPoint]:
         log.debug("Searching for vectors...")
         hits = self.client.search(
-            collection_name=self.collection_name, query_vector=query, limit=limit
+            collection_name=self.collection_name,
+            query_vector=query,
+            limit=limit,
+            score_threshold=0.6,
         )
         return hits
diff --git a/rag/llm/encoder.py b/rag/llm/encoder.py
index 7c03dc5..95f3c6a 100644
--- a/rag/llm/encoder.py
+++ b/rag/llm/encoder.py
@@ -1,5 +1,5 @@
 import os
-from typing import List
+from typing import Iterator, List
 from uuid import uuid4
 
 import ollama
@@ -7,8 +7,11 @@ from langchain_core.documents import Document
 from loguru import logger as log
 from qdrant_client.http.models import StrictFloat
 
-from rag.db.vector import Point
+try:
+    from rag.db.vector import Point
+except ModuleNotFoundError:
+    from db.vector import Point
 
 
 class Encoder:
     def __init__(self) -> None:
@@ -18,11 +21,11 @@ class Encoder:
     def __encode(self, prompt: str) -> List[StrictFloat]:
         return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"])
 
-    def encode_document(self, chunks: List[Document]) -> List[Point]:
+    def encode_document(self, chunks: Iterator[Document]) -> List[Point]:
         log.debug("Encoding document...")
         return [
             Point(
-                id=str(uuid4()),
+                id=uuid4().hex,
                 vector=self.__encode(chunk.page_content),
                 payload={"text": chunk.page_content},
             )
diff --git a/rag/llm/generator.py b/rag/llm/generator.py
index b0c6c40..8c7702f 100644
--- a/rag/llm/generator.py
+++ b/rag/llm/generator.py
@@ -15,9 +15,8 @@ class Generator:
     def __init__(self) -> None:
         self.model = os.environ["GENERATOR_MODEL"]
 
-    def __metaprompt(self, role: str, prompt: Prompt) -> str:
+    def __metaprompt(self, prompt: Prompt) -> str:
         metaprompt = (
-            f"You are a {role}.\n"
             "Answer the following question using the provided context.\n"
             "If you can't find the answer, do not pretend you know it,"
             'but answer "I don\'t know".\n\n'
@@ -28,7 +27,7 @@
         )
         return metaprompt
 
-    def generate(self, prompt: Prompt, role: str) -> str:
+    def generate(self, prompt: Prompt) -> str:
         log.debug("Generating answer...")
-        metaprompt = self.__metaprompt(role, prompt)
+        metaprompt = self.__metaprompt(prompt)
         return ollama.generate(model=self.model, prompt=metaprompt)
diff --git a/rag/main.py b/rag/main.py
deleted file mode 100644
index e69de29..0000000
diff --git a/rag/parser/pdf.py b/rag/parser/pdf.py
index ed4dc8b..cbd86a3 100644
--- a/rag/parser/pdf.py
+++ b/rag/parser/pdf.py
@@ -1,28 +1,24 @@
 import os
 from pathlib import Path
-from typing import Iterator, Optional
+from typing import Iterator
 
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_core.documents import Document
 from langchain_community.document_loaders.parsers.pdf import (
     PyPDFParser,
 )
-from rag.db.document import DocumentDB
+from langchain_community.document_loaders.blob_loaders import Blob
 
 
-class PDF:
+class PDFParser:
     def __init__(self) -> None:
-        self.db = DocumentDB()
         self.parser = PyPDFParser(password=None, extract_images=False)
 
-    def from_data(self, blob) -> Optional[Iterator[Document]]:
-        if self.db.add(blob):
-            yield from self.parser.parse(blob)
-        yield None
+    def from_data(self, blob: Blob) -> Iterator[Document]:
+        yield from self.parser.parse(blob)
 
-    def from_path(self, file_path: Path) -> Optional[Iterator[Document]]:
-        blob = Blob.from_path(file_path)
-        from_data(blob)
+    def from_path(self, path: Path) -> Blob:
+        return Blob.from_path(path)
 
     def chunk(self, content: Iterator[Document]):
         splitter = RecursiveCharacterTextSplitter(
diff --git a/rag/rag.py b/rag/rag.py
index 488e30a..cd4537e 100644
--- a/rag/rag.py
+++ b/rag/rag.py
@@ -1,3 +1,5 @@
+from dataclasses import dataclass
+from langchain_community.document_loaders.blob_loaders import Blob
 from pathlib import Path
 from typing import List
 
@@ -5,27 +7,46 @@ from dotenv import load_dotenv
 from loguru import logger as log
 from qdrant_client.models import StrictFloat
 
-from rag.db.document import DocumentDB
-from rag.db.vector import VectorDB
-from rag.llm.encoder import Encoder
-from rag.llm.generator import Generator, Prompt
-from rag.parser import pdf
+
+try:
+    from rag.db.vector import VectorDB
+    from rag.db.document import DocumentDB
+    from rag.llm.encoder import Encoder
+    from rag.llm.generator import Generator, Prompt
+    from rag.parser.pdf import PDFParser
+except ModuleNotFoundError:
+    from db.vector import VectorDB
+    from db.document import DocumentDB
+    from llm.encoder import Encoder
+    from llm.generator import Generator, Prompt
+    from parser.pdf import PDFParser
+
+
+@dataclass
+class Response:
+    query: str
+    context: List[str]
+    answer: str
 
 
 class RAG:
     def __init__(self) -> None:
         # FIXME: load this somewhere else?
         load_dotenv()
+        self.pdf_parser = PDFParser()
         self.generator = Generator()
         self.encoder = Encoder()
         self.vector_db = VectorDB()
+        self.doc_db = DocumentDB()
+
+    def add_pdf_from_path(self, path: Path):
+        blob = self.pdf_parser.from_path(path)
+        self.add_pdf_from_blob(blob)
 
-    # FIXME: refactor this, add vector?
-    def add_pdf(self, filepath: Path):
-        chunks = pdf.parser(filepath)
-        added = self.document_db.add(chunks)
-        if added:
-            log.debug(f"Adding pdf with filepath: {filepath} to vector db")
+    def add_pdf_from_blob(self, blob: Blob):
+        if self.doc_db.add(blob):
+            log.debug("Adding pdf to vector database...")
+            chunks = self.pdf_parser.from_data(blob)
             points = self.encoder.encode_document(chunks)
             self.vector_db.add(points)
         else:
@@ -34,10 +55,11 @@ class RAG:
     def __context(self, query_emb: List[StrictFloat], limit: int) -> str:
         hits = self.vector_db.search(query_emb, limit)
         log.debug(f"Got {len(hits)} hits in the vector db with limit={limit}")
-        return "\n".join(h.payload["text"] for h in hits)
+        return [h.payload["text"] for h in hits]
 
-    def rag(self, query: str, role: str, limit: int = 5) -> str:
+    def retrieve(self, query: str, limit: int = 5) -> Response:
         query_emb = self.encoder.encode_query(query)
         context = self.__context(query_emb, limit)
-        prompt = Prompt(query, context)
-        return self.generator.generate(prompt, role)["response"]
+        prompt = Prompt(query, "\n".join(context))
+        answer = self.generator.generate(prompt)["response"]
+        return Response(query, context, answer)
diff --git a/rag/ui.py b/rag/ui.py
index 1e4dd64..277c084 100644
--- a/rag/ui.py
+++ b/rag/ui.py
@@ -1,8 +1,14 @@
 import streamlit as st
+from langchain_community.document_loaders.blob_loaders import Blob
+
+try:
+    from rag.rag import RAG
+except ModuleNotFoundError:
+    from rag import RAG
+
+rag = RAG()
 
-# from loguru import logger as log
-# from rag.rag import RAG
 
 
 def upload_pdfs():
     files = st.file_uploader(
@@ -11,10 +17,35 @@ def upload_pdfs():
         accept_multiple_files=True,
     )
     for file in files:
-        bytes = file.read()
-        st.write(bytes)
+        blob = Blob.from_data(file.read())
+        rag.add_pdf_from_blob(blob)
 
 
 if __name__ == "__main__":
     st.header("RAG-UI")
+    upload_pdfs()
+
+    query = st.text_area(
+        "query",
+        key="query",
+        height=100,
+        placeholder="Enter query here",
+        help="",
+        label_visibility="collapsed",
+        disabled=False,
+    )
+
+    (result_column, context_column) = st.columns(2)
+
+    if query:
+        response = rag.retrieve(query)
+
+        with result_column:
+            st.markdown("### Answer")
+            st.markdown(response.answer)
+
+        with context_column:
+            st.markdown("### Context")
+            for c in response.context:
+                st.markdown(c)
+                st.markdown("---")
-- 
cgit v1.2.3-70-g09d2
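
For reference, a minimal sketch of how the API introduced by this patch is exercised end to end. It is not part of the commit; the PDF path and query string are placeholders, and it assumes the environment variables the modules read (e.g. GENERATOR_MODEL and the database settings) are set, with Ollama, Qdrant, and Postgres reachable:

    from pathlib import Path

    from rag.rag import RAG

    rag = RAG()

    # Parse and hash the PDF; only documents not already in the document
    # database are chunked, embedded, and indexed in the vector database.
    rag.add_pdf_from_path(Path("paper.pdf"))

    # Embed the query, fetch hits above the 0.6 score threshold, and
    # generate an answer grounded in the retrieved chunks.
    response = rag.retrieve("What is the paper about?")
    print(response.answer)
    for chunk in response.context:
        print(chunk)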