diff options
-rw-r--r-- | notebooks/testing.ipynb | 58 | ||||
-rw-r--r-- | pyproject.toml | 1 | ||||
-rw-r--r-- | rag/cli.py | 28 | ||||
-rw-r--r-- | rag/db/document.py | 7 | ||||
-rw-r--r-- | rag/db/vector.py | 5 | ||||
-rw-r--r-- | rag/llm/encoder.py | 11 | ||||
-rw-r--r-- | rag/llm/generator.py | 7 | ||||
-rw-r--r-- | rag/main.py | 0 | ||||
-rw-r--r-- | rag/parser/pdf.py | 18 | ||||
-rw-r--r-- | rag/rag.py | 52 | ||||
-rw-r--r-- | rag/ui.py | 39 |
11 files changed, 150 insertions, 76 deletions
diff --git a/notebooks/testing.ipynb b/notebooks/testing.ipynb index 2bd3dd6..36dd3af 100644 --- a/notebooks/testing.ipynb +++ b/notebooks/testing.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "id": "c1f56ae3-a056-4b31-bcab-27c2c97c00f1", "metadata": {}, "outputs": [], @@ -18,7 +18,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "id": "6b5cb12e-df7e-4532-b78b-216e11ed6161", "metadata": {}, "outputs": [], @@ -28,7 +28,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "id": "b8382795-9610-4b24-80b7-31397b2faf90", "metadata": {}, "outputs": [ @@ -36,8 +36,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-04-06 01:37:11.913\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.db.document\u001b[0m:\u001b[36m__configure\u001b[0m:\u001b[36m26\u001b[0m - \u001b[34m\u001b[1mCreating documents table if it does not exist...\u001b[0m\n", - "\u001b[32m2024-04-06 01:37:11.926\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.db.vector\u001b[0m:\u001b[36m__configure\u001b[0m:\u001b[36m36\u001b[0m - \u001b[34m\u001b[1mCollection knowledge-base already exists...\u001b[0m\n" + "\u001b[32m2024-04-07 21:14:48.364\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.db.vector\u001b[0m:\u001b[36m__configure\u001b[0m:\u001b[36m36\u001b[0m - \u001b[34m\u001b[1mCollection knowledge-base already exists...\u001b[0m\n", + "\u001b[32m2024-04-07 21:14:48.368\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.db.document\u001b[0m:\u001b[36m__configure\u001b[0m:\u001b[36m25\u001b[0m - \u001b[34m\u001b[1mCreating documents table if it does not exist...\u001b[0m\n" ] } ], @@ -47,27 +47,17 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "ac57e50d-1fc3-4fc9-90e5-5bdb97bd2f5e", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2024-04-06 01:37:17.243\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.db.document\u001b[0m:\u001b[36madd_document\u001b[0m:\u001b[36m37\u001b[0m - \u001b[34m\u001b[1mInserting document hash into documents db...\u001b[0m\n", - "\u001b[32m2024-04-06 01:37:17.244\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.db.document\u001b[0m:\u001b[36m__hash\u001b[0m:\u001b[36m32\u001b[0m - \u001b[34m\u001b[1mGenerating sha256 hash for pdf document\u001b[0m\n", - "\u001b[32m2024-04-06 01:37:17.247\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.rag\u001b[0m:\u001b[36madd_pdf\u001b[0m:\u001b[36m31\u001b[0m - \u001b[34m\u001b[1mDocument already exists!\u001b[0m\n" - ] - } - ], + "outputs": [], "source": [ - "rag.add_pdf(path)" + "rag.add_pdf(path)\n" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "1c6b48d2-eb04-4a7c-8224-78aabfc7c887", "metadata": {}, "outputs": [], @@ -77,39 +67,39 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "a95c8250-00b2-4cbc-a9c6-a76d14ef2da5", "metadata": {}, + "outputs": [], + "source": [ + "rag.rag(query, \"quant researcher\", limit=5)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "2c28db8c-c2bb-4092-b1d3-fd3f8bb060b5", + "metadata": {}, "outputs": [ { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2024-04-06 01:37:17.265\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.llm.encoder\u001b[0m:\u001b[36mencode_query\u001b[0m:\u001b[36m33\u001b[0m - \u001b[34m\u001b[1mEncoding query: What is a factor model?\u001b[0m\n", - "\u001b[32m2024-04-06 01:37:17.858\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.db.vector\u001b[0m:\u001b[36msearch\u001b[0m:\u001b[36m51\u001b[0m - \u001b[34m\u001b[1mSearching for vectors...\u001b[0m\n", - "\u001b[32m2024-04-06 01:37:17.864\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.rag\u001b[0m:\u001b[36m__context\u001b[0m:\u001b[36m35\u001b[0m - \u001b[34m\u001b[1mGot 5 hits in the vector db with limit=5\u001b[0m\n", - "\u001b[32m2024-04-06 01:37:17.865\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.llm.generator\u001b[0m:\u001b[36mgenerate\u001b[0m:\u001b[36m32\u001b[0m - \u001b[34m\u001b[1mGenerating answer...\u001b[0m\n" - ] - }, - { "data": { "text/plain": [ - "'A factor model is a type of model used to explain the returns or movements of financial assets by decomposing them into two parts: systematic risk, which is driven by a small number of factors affecting many securities; and idiosyncratic risk, which is specific to individual stocks. The general factor model is rt = φ0 + h(ft) + wt, where rt denotes the return of an asset at time t, φ0 represents a constant vector, ft is a vector of factors responsible for most of the randomness in the market, and h is a function that summarizes how these low-dimensional factors affect higher-dimensional markets. The residual wt accounts for any remaining uncorrelated perturbations with only a marginal effect on returns.'" + "True" ] }, - "execution_count": 6, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "rag.rag(query, \"quant researcher\", limit=5)" + "rag.vector_db.client.delete_collection(\"knowledge-base\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "11ac0d46-8589-4700-abf7-7959afbf611c", + "id": "05f7068c-b4c6-47b2-ac62-79c021838500", "metadata": {}, "outputs": [], "source": [] diff --git a/pyproject.toml b/pyproject.toml index 1d611c1..daec64e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,6 +4,7 @@ version = "0.1.0" description = "" authors = ["Gustaf Rydholm <gustaf.rydholm@gmail.com>"] readme = "README.md" +packages = [{include = "rag"}] [tool.poetry.dependencies] python = "^3.11" diff --git a/rag/cli.py b/rag/cli.py new file mode 100644 index 0000000..5ea1a47 --- /dev/null +++ b/rag/cli.py @@ -0,0 +1,28 @@ +from pathlib import Path + + +try: + from rag.rag import RAG +except ModuleNotFoundError: + from rag import RAG + +if __name__ == "__main__": + rag = RAG() + + while True: + print("Retrieval Augmented Generation") + choice = input("1. add pdf from path\n2. Enter a query\n") + match choice: + case "1": + path = input("Enter the path to the pdf: ") + path = Path(path) + rag.add_pdf_from_path(path) + case "2": + query = input("Enter your query: ") + if query: + result = rag.retrive(query) + print("Answer: \n") + print(result.answer) + case _: + print("Invalid option!") + diff --git a/rag/db/document.py b/rag/db/document.py index 528a399..54ac451 100644 --- a/rag/db/document.py +++ b/rag/db/document.py @@ -1,6 +1,7 @@ import hashlib import os +from langchain_community.document_loaders.blob_loaders import Blob import psycopg from loguru import logger as log @@ -26,11 +27,11 @@ class DocumentDB: cur.execute(TABLES) self.conn.commit() - def __hash(self, blob: bytes) -> str: + def __hash(self, blob: Blob) -> str: log.debug("Hashing document...") - return hashlib.sha256(blob).hexdigest() + return hashlib.sha256(blob.as_bytes()).hexdigest() - def add(self, blob: bytes) -> bool: + def add(self, blob: Blob) -> bool: with self.conn.cursor() as cur: hash = self.__hash(blob) cur.execute( diff --git a/rag/db/vector.py b/rag/db/vector.py index 4aa62cc..bbbbf32 100644 --- a/rag/db/vector.py +++ b/rag/db/vector.py @@ -50,6 +50,9 @@ class VectorDB: def search(self, query: List[float], limit: int = 4) -> List[ScoredPoint]: log.debug("Searching for vectors...") hits = self.client.search( - collection_name=self.collection_name, query_vector=query, limit=limit + collection_name=self.collection_name, + query_vector=query, + limit=limit, + score_threshold=0.6, ) return hits diff --git a/rag/llm/encoder.py b/rag/llm/encoder.py index 7c03dc5..95f3c6a 100644 --- a/rag/llm/encoder.py +++ b/rag/llm/encoder.py @@ -1,5 +1,5 @@ import os -from typing import List +from typing import Iterator, List from uuid import uuid4 import ollama @@ -7,8 +7,11 @@ from langchain_core.documents import Document from loguru import logger as log from qdrant_client.http.models import StrictFloat -from rag.db.vector import Point +try: + from rag.db.vector import Point +except ModuleNotFoundError: + from db.vector import Point class Encoder: def __init__(self) -> None: @@ -18,11 +21,11 @@ class Encoder: def __encode(self, prompt: str) -> List[StrictFloat]: return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"]) - def encode_document(self, chunks: List[Document]) -> List[Point]: + def encode_document(self, chunks: Iterator[Document]) -> List[Point]: log.debug("Encoding document...") return [ Point( - id=str(uuid4()), + id=uuid4().hex, vector=self.__encode(chunk.page_content), payload={"text": chunk.page_content}, ) diff --git a/rag/llm/generator.py b/rag/llm/generator.py index b0c6c40..8c7702f 100644 --- a/rag/llm/generator.py +++ b/rag/llm/generator.py @@ -15,9 +15,8 @@ class Generator: def __init__(self) -> None: self.model = os.environ["GENERATOR_MODEL"] - def __metaprompt(self, role: str, prompt: Prompt) -> str: + def __metaprompt(self, prompt: Prompt) -> str: metaprompt = ( - f"You are a {role}.\n" "Answer the following question using the provided context.\n" "If you can't find the answer, do not pretend you know it," 'but answer "I don\'t know".\n\n' @@ -28,7 +27,7 @@ class Generator: ) return metaprompt - def generate(self, prompt: Prompt, role: str) -> str: + def generate(self, prompt: Prompt) -> str: log.debug("Generating answer...") - metaprompt = self.__metaprompt(role, prompt) + metaprompt = self.__metaprompt(prompt) return ollama.generate(model=self.model, prompt=metaprompt) diff --git a/rag/main.py b/rag/main.py deleted file mode 100644 index e69de29..0000000 --- a/rag/main.py +++ /dev/null diff --git a/rag/parser/pdf.py b/rag/parser/pdf.py index ed4dc8b..cbd86a3 100644 --- a/rag/parser/pdf.py +++ b/rag/parser/pdf.py @@ -1,28 +1,24 @@ import os from pathlib import Path -from typing import Iterator, Optional +from typing import Iterator from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_core.documents import Document from langchain_community.document_loaders.parsers.pdf import ( PyPDFParser, ) -from rag.db.document import DocumentDB +from langchain_community.document_loaders.blob_loaders import Blob -class PDF: +class PDFParser: def __init__(self) -> None: - self.db = DocumentDB() self.parser = PyPDFParser(password=None, extract_images=False) - def from_data(self, blob) -> Optional[Iterator[Document]]: - if self.db.add(blob): - yield from self.parser.parse(blob) - yield None + def from_data(self, blob: Blob) -> Iterator[Document]: + yield from self.parser.parse(blob) - def from_path(self, file_path: Path) -> Optional[Iterator[Document]]: - blob = Blob.from_path(file_path) - from_data(blob) + def from_path(self, path: Path) -> Iterator[Document]: + return Blob.from_path(path) def chunk(self, content: Iterator[Document]): splitter = RecursiveCharacterTextSplitter( @@ -1,3 +1,5 @@ +from dataclasses import dataclass +from io import BytesIO from pathlib import Path from typing import List @@ -5,27 +7,46 @@ from dotenv import load_dotenv from loguru import logger as log from qdrant_client.models import StrictFloat -from rag.db.document import DocumentDB -from rag.db.vector import VectorDB -from rag.llm.encoder import Encoder -from rag.llm.generator import Generator, Prompt -from rag.parser import pdf + +try: + from rag.db.vector import VectorDB + from rag.db.document import DocumentDB + from rag.llm.encoder import Encoder + from rag.llm.generator import Generator, Prompt + from rag.parser.pdf import PDFParser +except ModuleNotFoundError: + from db.vector import VectorDB + from db.document import DocumentDB + from llm.encoder import Encoder + from llm.generator import Generator, Prompt + from parser.pdf import PDFParser + + +@dataclass +class Response: + query: str + context: List[str] + answer: str class RAG: def __init__(self) -> None: # FIXME: load this somewhere else? load_dotenv() + self.pdf_parser = PDFParser() self.generator = Generator() self.encoder = Encoder() self.vector_db = VectorDB() + self.doc_db = DocumentDB() + + def add_pdf_from_path(self, path: Path): + blob = self.pdf_parser.from_path(path) + self.add_pdf_from_blob(blob) - # FIXME: refactor this, add vector? - def add_pdf(self, filepath: Path): - chunks = pdf.parser(filepath) - added = self.document_db.add(chunks) - if added: - log.debug(f"Adding pdf with filepath: {filepath} to vector db") + def add_pdf_from_blob(self, blob: BytesIO): + if self.doc_db.add(blob): + log.debug("Adding pdf to vector database...") + chunks = self.pdf_parser.from_data(blob) points = self.encoder.encode_document(chunks) self.vector_db.add(points) else: @@ -34,10 +55,11 @@ class RAG: def __context(self, query_emb: List[StrictFloat], limit: int) -> str: hits = self.vector_db.search(query_emb, limit) log.debug(f"Got {len(hits)} hits in the vector db with limit={limit}") - return "\n".join(h.payload["text"] for h in hits) + return [h.payload["text"] for h in hits] - def rag(self, query: str, role: str, limit: int = 5) -> str: + def retrive(self, query: str, limit: int = 5) -> Response: query_emb = self.encoder.encode_query(query) context = self.__context(query_emb, limit) - prompt = Prompt(query, context) - return self.generator.generate(prompt, role)["response"] + prompt = Prompt(query, "\n".join(context)) + answer = self.generator.generate(prompt)["response"] + return Response(query, context, answer) @@ -1,8 +1,14 @@ import streamlit as st +from langchain_community.document_loaders.blob_loaders import Blob + +try: + from rag.rag import RAG +except ModuleNotFoundError: + from rag import RAG + +rag = RAG() -# from loguru import logger as log -# from rag.rag import RAG def upload_pdfs(): files = st.file_uploader( @@ -11,10 +17,35 @@ def upload_pdfs(): accept_multiple_files=True, ) for file in files: - bytes = file.read() - st.write(bytes) + blob = Blob.from_data(file.read()) + rag.add_pdf_from_blob(blob) if __name__ == "__main__": st.header("RAG-UI") + upload_pdfs() + query = st.text_area( + "query", + key="query", + height=100, + placeholder="Enter query here", + help="", + label_visibility="collapsed", + disabled=False, + ) + + (result_column, context_column) = st.columns(2) + + if query: + response = rag.retrive(query) + + with result_column: + st.markdown("### Answer") + st.markdown(response.answer) + + with context_column: + st.markdown("### Context") + for c in response.context: + st.markdown(c) + st.markdown("---") |