diff options
| -rw-r--r-- | notebooks/testing.ipynb | 58 | ||||
| -rw-r--r-- | pyproject.toml | 1 | ||||
| -rw-r--r-- | rag/cli.py | 28 | ||||
| -rw-r--r-- | rag/db/document.py | 7 | ||||
| -rw-r--r-- | rag/db/vector.py | 5 | ||||
| -rw-r--r-- | rag/llm/encoder.py | 11 | ||||
| -rw-r--r-- | rag/llm/generator.py | 7 | ||||
| -rw-r--r-- | rag/main.py | 0 | ||||
| -rw-r--r-- | rag/parser/pdf.py | 18 | ||||
| -rw-r--r-- | rag/rag.py | 52 | ||||
| -rw-r--r-- | rag/ui.py | 39 | 
11 files changed, 150 insertions, 76 deletions
| diff --git a/notebooks/testing.ipynb b/notebooks/testing.ipynb index 2bd3dd6..36dd3af 100644 --- a/notebooks/testing.ipynb +++ b/notebooks/testing.ipynb @@ -2,7 +2,7 @@   "cells": [    {     "cell_type": "code", -   "execution_count": 1, +   "execution_count": 2,     "id": "c1f56ae3-a056-4b31-bcab-27c2c97c00f1",     "metadata": {},     "outputs": [], @@ -18,7 +18,7 @@    },    {     "cell_type": "code", -   "execution_count": 2, +   "execution_count": 3,     "id": "6b5cb12e-df7e-4532-b78b-216e11ed6161",     "metadata": {},     "outputs": [], @@ -28,7 +28,7 @@    },    {     "cell_type": "code", -   "execution_count": 3, +   "execution_count": 4,     "id": "b8382795-9610-4b24-80b7-31397b2faf90",     "metadata": {},     "outputs": [ @@ -36,8 +36,8 @@       "name": "stderr",       "output_type": "stream",       "text": [ -      "\u001b[32m2024-04-06 01:37:11.913\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mrag.db.document\u001b[0m:\u001b[36m__configure\u001b[0m:\u001b[36m26\u001b[0m - \u001b[34m\u001b[1mCreating documents table if it does not exist...\u001b[0m\n", -      "\u001b[32m2024-04-06 01:37:11.926\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mrag.db.vector\u001b[0m:\u001b[36m__configure\u001b[0m:\u001b[36m36\u001b[0m - \u001b[34m\u001b[1mCollection knowledge-base already exists...\u001b[0m\n" +      "\u001b[32m2024-04-07 21:14:48.364\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mrag.db.vector\u001b[0m:\u001b[36m__configure\u001b[0m:\u001b[36m36\u001b[0m - \u001b[34m\u001b[1mCollection knowledge-base already exists...\u001b[0m\n", +      "\u001b[32m2024-04-07 21:14:48.368\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mrag.db.document\u001b[0m:\u001b[36m__configure\u001b[0m:\u001b[36m25\u001b[0m - \u001b[34m\u001b[1mCreating documents table if it does not exist...\u001b[0m\n"       ]      }     ], @@ -47,27 +47,17 @@    },    {     "cell_type": "code", -   "execution_count": 4, +   "execution_count": null,     "id": "ac57e50d-1fc3-4fc9-90e5-5bdb97bd2f5e",     "metadata": {}, -   "outputs": [ -    { -     "name": "stderr", -     "output_type": "stream", -     "text": [ -      "\u001b[32m2024-04-06 01:37:17.243\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mrag.db.document\u001b[0m:\u001b[36madd_document\u001b[0m:\u001b[36m37\u001b[0m - \u001b[34m\u001b[1mInserting document hash into documents db...\u001b[0m\n", -      "\u001b[32m2024-04-06 01:37:17.244\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mrag.db.document\u001b[0m:\u001b[36m__hash\u001b[0m:\u001b[36m32\u001b[0m - \u001b[34m\u001b[1mGenerating sha256 hash for pdf document\u001b[0m\n", -      "\u001b[32m2024-04-06 01:37:17.247\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mrag.rag\u001b[0m:\u001b[36madd_pdf\u001b[0m:\u001b[36m31\u001b[0m - \u001b[34m\u001b[1mDocument already exists!\u001b[0m\n" -     ] -    } -   ], +   "outputs": [],     "source": [ -    "rag.add_pdf(path)" +    "rag.add_pdf(path)\n"     ]    },    {     "cell_type": "code", -   "execution_count": 5, +   "execution_count": null,     "id": "1c6b48d2-eb04-4a7c-8224-78aabfc7c887",     "metadata": {},     "outputs": [], @@ -77,39 +67,39 @@    },    {     "cell_type": "code", -   "execution_count": 6, +   "execution_count": null,     "id": "a95c8250-00b2-4cbc-a9c6-a76d14ef2da5",     "metadata": {}, +   "outputs": [], +   "source": [ +    "rag.rag(query, \"quant researcher\", limit=5)" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 8, +   "id": "2c28db8c-c2bb-4092-b1d3-fd3f8bb060b5", +   "metadata": {},     "outputs": [      { -     "name": "stderr", -     "output_type": "stream", -     "text": [ -      "\u001b[32m2024-04-06 01:37:17.265\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mrag.llm.encoder\u001b[0m:\u001b[36mencode_query\u001b[0m:\u001b[36m33\u001b[0m - \u001b[34m\u001b[1mEncoding query: What is a factor model?\u001b[0m\n", -      "\u001b[32m2024-04-06 01:37:17.858\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mrag.db.vector\u001b[0m:\u001b[36msearch\u001b[0m:\u001b[36m51\u001b[0m - \u001b[34m\u001b[1mSearching for vectors...\u001b[0m\n", -      "\u001b[32m2024-04-06 01:37:17.864\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mrag.rag\u001b[0m:\u001b[36m__context\u001b[0m:\u001b[36m35\u001b[0m - \u001b[34m\u001b[1mGot 5 hits in the vector db with limit=5\u001b[0m\n", -      "\u001b[32m2024-04-06 01:37:17.865\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mrag.llm.generator\u001b[0m:\u001b[36mgenerate\u001b[0m:\u001b[36m32\u001b[0m - \u001b[34m\u001b[1mGenerating answer...\u001b[0m\n" -     ] -    }, -    {       "data": {        "text/plain": [ -       "'A factor model is a type of model used to explain the returns or movements of financial assets by decomposing them into two parts: systematic risk, which is driven by a small number of factors affecting many securities; and idiosyncratic risk, which is specific to individual stocks. The general factor model is rt = φ0 + h(ft) + wt, where rt denotes the return of an asset at time t, φ0 represents a constant vector, ft is a vector of factors responsible for most of the randomness in the market, and h is a function that summarizes how these low-dimensional factors affect higher-dimensional markets. The residual wt accounts for any remaining uncorrelated perturbations with only a marginal effect on returns.'" +       "True"        ]       }, -     "execution_count": 6, +     "execution_count": 8,       "metadata": {},       "output_type": "execute_result"      }     ],     "source": [ -    "rag.rag(query, \"quant researcher\", limit=5)" +    "rag.vector_db.client.delete_collection(\"knowledge-base\")"     ]    },    {     "cell_type": "code",     "execution_count": null, -   "id": "11ac0d46-8589-4700-abf7-7959afbf611c", +   "id": "05f7068c-b4c6-47b2-ac62-79c021838500",     "metadata": {},     "outputs": [],     "source": [] diff --git a/pyproject.toml b/pyproject.toml index 1d611c1..daec64e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,6 +4,7 @@ version = "0.1.0"  description = ""  authors = ["Gustaf Rydholm <gustaf.rydholm@gmail.com>"]  readme = "README.md" +packages = [{include = "rag"}]  [tool.poetry.dependencies]  python = "^3.11" diff --git a/rag/cli.py b/rag/cli.py new file mode 100644 index 0000000..5ea1a47 --- /dev/null +++ b/rag/cli.py @@ -0,0 +1,28 @@ +from pathlib import Path + + +try: +    from rag.rag import RAG +except ModuleNotFoundError: +    from rag import RAG + +if __name__ == "__main__": +    rag = RAG() + +    while True: +        print("Retrieval Augmented Generation") +        choice = input("1. add pdf from path\n2. Enter a query\n") +        match choice: +            case "1": +                path = input("Enter the path to the pdf: ") +                path = Path(path) +                rag.add_pdf_from_path(path) +            case "2": +                query = input("Enter your query: ") +                if query: +                    result = rag.retrive(query) +                    print("Answer: \n") +                    print(result.answer) +            case _: +                print("Invalid option!") + diff --git a/rag/db/document.py b/rag/db/document.py index 528a399..54ac451 100644 --- a/rag/db/document.py +++ b/rag/db/document.py @@ -1,6 +1,7 @@  import hashlib  import os +from langchain_community.document_loaders.blob_loaders import Blob  import psycopg  from loguru import logger as log @@ -26,11 +27,11 @@ class DocumentDB:              cur.execute(TABLES)              self.conn.commit() -    def __hash(self, blob: bytes) -> str: +    def __hash(self, blob: Blob) -> str:          log.debug("Hashing document...") -        return hashlib.sha256(blob).hexdigest() +        return hashlib.sha256(blob.as_bytes()).hexdigest() -    def add(self, blob: bytes) -> bool: +    def add(self, blob: Blob) -> bool:          with self.conn.cursor() as cur:              hash = self.__hash(blob)              cur.execute( diff --git a/rag/db/vector.py b/rag/db/vector.py index 4aa62cc..bbbbf32 100644 --- a/rag/db/vector.py +++ b/rag/db/vector.py @@ -50,6 +50,9 @@ class VectorDB:      def search(self, query: List[float], limit: int = 4) -> List[ScoredPoint]:          log.debug("Searching for vectors...")          hits = self.client.search( -            collection_name=self.collection_name, query_vector=query, limit=limit +            collection_name=self.collection_name, +            query_vector=query, +            limit=limit, +            score_threshold=0.6,          )          return hits diff --git a/rag/llm/encoder.py b/rag/llm/encoder.py index 7c03dc5..95f3c6a 100644 --- a/rag/llm/encoder.py +++ b/rag/llm/encoder.py @@ -1,5 +1,5 @@  import os -from typing import List +from typing import Iterator, List  from uuid import uuid4  import ollama @@ -7,8 +7,11 @@ from langchain_core.documents import Document  from loguru import logger as log  from qdrant_client.http.models import StrictFloat -from rag.db.vector import Point +try: +    from rag.db.vector import Point +except ModuleNotFoundError: +    from db.vector import Point  class Encoder:      def __init__(self) -> None: @@ -18,11 +21,11 @@ class Encoder:      def __encode(self, prompt: str) -> List[StrictFloat]:          return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"]) -    def encode_document(self, chunks: List[Document]) -> List[Point]: +    def encode_document(self, chunks: Iterator[Document]) -> List[Point]:          log.debug("Encoding document...")          return [              Point( -                id=str(uuid4()), +                id=uuid4().hex,                  vector=self.__encode(chunk.page_content),                  payload={"text": chunk.page_content},              ) diff --git a/rag/llm/generator.py b/rag/llm/generator.py index b0c6c40..8c7702f 100644 --- a/rag/llm/generator.py +++ b/rag/llm/generator.py @@ -15,9 +15,8 @@ class Generator:      def __init__(self) -> None:          self.model = os.environ["GENERATOR_MODEL"] -    def __metaprompt(self, role: str, prompt: Prompt) -> str: +    def __metaprompt(self, prompt: Prompt) -> str:          metaprompt = ( -            f"You are a {role}.\n"              "Answer the following question using the provided context.\n"              "If you can't find the answer, do not pretend you know it,"              'but answer "I don\'t know".\n\n' @@ -28,7 +27,7 @@ class Generator:          )          return metaprompt -    def generate(self, prompt: Prompt, role: str) -> str: +    def generate(self, prompt: Prompt) -> str:          log.debug("Generating answer...") -        metaprompt = self.__metaprompt(role, prompt) +        metaprompt = self.__metaprompt(prompt)          return ollama.generate(model=self.model, prompt=metaprompt) diff --git a/rag/main.py b/rag/main.py deleted file mode 100644 index e69de29..0000000 --- a/rag/main.py +++ /dev/null diff --git a/rag/parser/pdf.py b/rag/parser/pdf.py index ed4dc8b..cbd86a3 100644 --- a/rag/parser/pdf.py +++ b/rag/parser/pdf.py @@ -1,28 +1,24 @@  import os  from pathlib import Path -from typing import Iterator, Optional +from typing import Iterator  from langchain.text_splitter import RecursiveCharacterTextSplitter  from langchain_core.documents import Document  from langchain_community.document_loaders.parsers.pdf import (      PyPDFParser,  ) -from rag.db.document import DocumentDB +from langchain_community.document_loaders.blob_loaders import Blob -class PDF: +class PDFParser:      def __init__(self) -> None: -        self.db = DocumentDB()          self.parser = PyPDFParser(password=None, extract_images=False) -    def from_data(self, blob) -> Optional[Iterator[Document]]: -        if self.db.add(blob): -            yield from self.parser.parse(blob) -        yield None +    def from_data(self, blob: Blob) -> Iterator[Document]: +        yield from self.parser.parse(blob) -    def from_path(self, file_path: Path) -> Optional[Iterator[Document]]: -        blob = Blob.from_path(file_path) -        from_data(blob) +    def from_path(self, path: Path) -> Iterator[Document]: +        return Blob.from_path(path)      def chunk(self, content: Iterator[Document]):          splitter = RecursiveCharacterTextSplitter( @@ -1,3 +1,5 @@ +from dataclasses import dataclass +from io import BytesIO  from pathlib import Path  from typing import List @@ -5,27 +7,46 @@ from dotenv import load_dotenv  from loguru import logger as log  from qdrant_client.models import StrictFloat -from rag.db.document import DocumentDB -from rag.db.vector import VectorDB -from rag.llm.encoder import Encoder -from rag.llm.generator import Generator, Prompt -from rag.parser import pdf + +try: +    from rag.db.vector import VectorDB +    from rag.db.document import DocumentDB +    from rag.llm.encoder import Encoder +    from rag.llm.generator import Generator, Prompt +    from rag.parser.pdf import PDFParser +except ModuleNotFoundError: +    from db.vector import VectorDB +    from db.document import DocumentDB +    from llm.encoder import Encoder +    from llm.generator import Generator, Prompt +    from parser.pdf import PDFParser + + +@dataclass +class Response: +    query: str +    context: List[str] +    answer: str  class RAG:      def __init__(self) -> None:          # FIXME: load this somewhere else?          load_dotenv() +        self.pdf_parser = PDFParser()          self.generator = Generator()          self.encoder = Encoder()          self.vector_db = VectorDB() +        self.doc_db = DocumentDB() + +    def add_pdf_from_path(self, path: Path): +        blob = self.pdf_parser.from_path(path) +        self.add_pdf_from_blob(blob) -    # FIXME: refactor this, add vector? -    def add_pdf(self, filepath: Path): -        chunks = pdf.parser(filepath) -        added = self.document_db.add(chunks) -        if added: -            log.debug(f"Adding pdf with filepath: {filepath} to vector db") +    def add_pdf_from_blob(self, blob: BytesIO): +        if self.doc_db.add(blob): +            log.debug("Adding pdf to vector database...") +            chunks = self.pdf_parser.from_data(blob)              points = self.encoder.encode_document(chunks)              self.vector_db.add(points)          else: @@ -34,10 +55,11 @@ class RAG:      def __context(self, query_emb: List[StrictFloat], limit: int) -> str:          hits = self.vector_db.search(query_emb, limit)          log.debug(f"Got {len(hits)} hits in the vector db with limit={limit}") -        return "\n".join(h.payload["text"] for h in hits) +        return [h.payload["text"] for h in hits] -    def rag(self, query: str, role: str, limit: int = 5) -> str: +    def retrive(self, query: str, limit: int = 5) -> Response:          query_emb = self.encoder.encode_query(query)          context = self.__context(query_emb, limit) -        prompt = Prompt(query, context) -        return self.generator.generate(prompt, role)["response"] +        prompt = Prompt(query, "\n".join(context)) +        answer = self.generator.generate(prompt)["response"] +        return Response(query, context, answer) @@ -1,8 +1,14 @@  import streamlit as st +from langchain_community.document_loaders.blob_loaders import Blob + +try: +    from rag.rag import RAG +except ModuleNotFoundError: +    from rag import RAG + +rag = RAG() -# from loguru import logger as log -# from rag.rag import RAG  def upload_pdfs():      files = st.file_uploader( @@ -11,10 +17,35 @@ def upload_pdfs():          accept_multiple_files=True,      )      for file in files: -        bytes = file.read() -        st.write(bytes) +        blob = Blob.from_data(file.read()) +        rag.add_pdf_from_blob(blob)  if __name__ == "__main__":      st.header("RAG-UI") +      upload_pdfs() +    query = st.text_area( +        "query", +        key="query", +        height=100, +        placeholder="Enter query here", +        help="", +        label_visibility="collapsed", +        disabled=False, +    ) + +    (result_column, context_column) = st.columns(2) + +    if query: +        response = rag.retrive(query) + +        with result_column: +            st.markdown("### Answer") +            st.markdown(response.answer) + +        with context_column: +            st.markdown("### Context") +            for c in response.context: +                st.markdown(c) +                st.markdown("---") |