summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-08 00:23:52 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-08 00:23:52 +0200
commit8211705debf9d1335223c606275f46c43c78d8a2 (patch)
treef09f902c7d31b2035813c42cbb4a47e720fa363b
parent95f47c4900a96d91daaef93bf87094ed3d4da43c (diff)
Updates
-rw-r--r--notebooks/testing.ipynb58
-rw-r--r--pyproject.toml1
-rw-r--r--rag/cli.py28
-rw-r--r--rag/db/document.py7
-rw-r--r--rag/db/vector.py5
-rw-r--r--rag/llm/encoder.py11
-rw-r--r--rag/llm/generator.py7
-rw-r--r--rag/main.py0
-rw-r--r--rag/parser/pdf.py18
-rw-r--r--rag/rag.py52
-rw-r--r--rag/ui.py39
11 files changed, 150 insertions, 76 deletions
diff --git a/notebooks/testing.ipynb b/notebooks/testing.ipynb
index 2bd3dd6..36dd3af 100644
--- a/notebooks/testing.ipynb
+++ b/notebooks/testing.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 2,
"id": "c1f56ae3-a056-4b31-bcab-27c2c97c00f1",
"metadata": {},
"outputs": [],
@@ -18,7 +18,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 3,
"id": "6b5cb12e-df7e-4532-b78b-216e11ed6161",
"metadata": {},
"outputs": [],
@@ -28,7 +28,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 4,
"id": "b8382795-9610-4b24-80b7-31397b2faf90",
"metadata": {},
"outputs": [
@@ -36,8 +36,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "\u001b[32m2024-04-06 01:37:11.913\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.db.document\u001b[0m:\u001b[36m__configure\u001b[0m:\u001b[36m26\u001b[0m - \u001b[34m\u001b[1mCreating documents table if it does not exist...\u001b[0m\n",
- "\u001b[32m2024-04-06 01:37:11.926\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.db.vector\u001b[0m:\u001b[36m__configure\u001b[0m:\u001b[36m36\u001b[0m - \u001b[34m\u001b[1mCollection knowledge-base already exists...\u001b[0m\n"
+ "\u001b[32m2024-04-07 21:14:48.364\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.db.vector\u001b[0m:\u001b[36m__configure\u001b[0m:\u001b[36m36\u001b[0m - \u001b[34m\u001b[1mCollection knowledge-base already exists...\u001b[0m\n",
+ "\u001b[32m2024-04-07 21:14:48.368\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.db.document\u001b[0m:\u001b[36m__configure\u001b[0m:\u001b[36m25\u001b[0m - \u001b[34m\u001b[1mCreating documents table if it does not exist...\u001b[0m\n"
]
}
],
@@ -47,27 +47,17 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": null,
"id": "ac57e50d-1fc3-4fc9-90e5-5bdb97bd2f5e",
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "\u001b[32m2024-04-06 01:37:17.243\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.db.document\u001b[0m:\u001b[36madd_document\u001b[0m:\u001b[36m37\u001b[0m - \u001b[34m\u001b[1mInserting document hash into documents db...\u001b[0m\n",
- "\u001b[32m2024-04-06 01:37:17.244\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.db.document\u001b[0m:\u001b[36m__hash\u001b[0m:\u001b[36m32\u001b[0m - \u001b[34m\u001b[1mGenerating sha256 hash for pdf document\u001b[0m\n",
- "\u001b[32m2024-04-06 01:37:17.247\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.rag\u001b[0m:\u001b[36madd_pdf\u001b[0m:\u001b[36m31\u001b[0m - \u001b[34m\u001b[1mDocument already exists!\u001b[0m\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
- "rag.add_pdf(path)"
+ "rag.add_pdf(path)\n"
]
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": null,
"id": "1c6b48d2-eb04-4a7c-8224-78aabfc7c887",
"metadata": {},
"outputs": [],
@@ -77,39 +67,39 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": null,
"id": "a95c8250-00b2-4cbc-a9c6-a76d14ef2da5",
"metadata": {},
+ "outputs": [],
+ "source": [
+ "rag.rag(query, \"quant researcher\", limit=5)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "2c28db8c-c2bb-4092-b1d3-fd3f8bb060b5",
+ "metadata": {},
"outputs": [
{
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "\u001b[32m2024-04-06 01:37:17.265\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.llm.encoder\u001b[0m:\u001b[36mencode_query\u001b[0m:\u001b[36m33\u001b[0m - \u001b[34m\u001b[1mEncoding query: What is a factor model?\u001b[0m\n",
- "\u001b[32m2024-04-06 01:37:17.858\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.db.vector\u001b[0m:\u001b[36msearch\u001b[0m:\u001b[36m51\u001b[0m - \u001b[34m\u001b[1mSearching for vectors...\u001b[0m\n",
- "\u001b[32m2024-04-06 01:37:17.864\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.rag\u001b[0m:\u001b[36m__context\u001b[0m:\u001b[36m35\u001b[0m - \u001b[34m\u001b[1mGot 5 hits in the vector db with limit=5\u001b[0m\n",
- "\u001b[32m2024-04-06 01:37:17.865\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.llm.generator\u001b[0m:\u001b[36mgenerate\u001b[0m:\u001b[36m32\u001b[0m - \u001b[34m\u001b[1mGenerating answer...\u001b[0m\n"
- ]
- },
- {
"data": {
"text/plain": [
- "'A factor model is a type of model used to explain the returns or movements of financial assets by decomposing them into two parts: systematic risk, which is driven by a small number of factors affecting many securities; and idiosyncratic risk, which is specific to individual stocks. The general factor model is rt = φ0 + h(ft) + wt, where rt denotes the return of an asset at time t, φ0 represents a constant vector, ft is a vector of factors responsible for most of the randomness in the market, and h is a function that summarizes how these low-dimensional factors affect higher-dimensional markets. The residual wt accounts for any remaining uncorrelated perturbations with only a marginal effect on returns.'"
+ "True"
]
},
- "execution_count": 6,
+ "execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "rag.rag(query, \"quant researcher\", limit=5)"
+ "rag.vector_db.client.delete_collection(\"knowledge-base\")"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "11ac0d46-8589-4700-abf7-7959afbf611c",
+ "id": "05f7068c-b4c6-47b2-ac62-79c021838500",
"metadata": {},
"outputs": [],
"source": []
diff --git a/pyproject.toml b/pyproject.toml
index 1d611c1..daec64e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,6 +4,7 @@ version = "0.1.0"
description = ""
authors = ["Gustaf Rydholm <gustaf.rydholm@gmail.com>"]
readme = "README.md"
+packages = [{include = "rag"}]
[tool.poetry.dependencies]
python = "^3.11"
diff --git a/rag/cli.py b/rag/cli.py
new file mode 100644
index 0000000..5ea1a47
--- /dev/null
+++ b/rag/cli.py
@@ -0,0 +1,28 @@
+from pathlib import Path
+
+
+try:
+ from rag.rag import RAG
+except ModuleNotFoundError:
+ from rag import RAG
+
+if __name__ == "__main__":
+ rag = RAG()
+
+ while True:
+ print("Retrieval Augmented Generation")
+ choice = input("1. add pdf from path\n2. Enter a query\n")
+ match choice:
+ case "1":
+ path = input("Enter the path to the pdf: ")
+ path = Path(path)
+ rag.add_pdf_from_path(path)
+ case "2":
+ query = input("Enter your query: ")
+ if query:
+ result = rag.retrive(query)
+ print("Answer: \n")
+ print(result.answer)
+ case _:
+ print("Invalid option!")
+
diff --git a/rag/db/document.py b/rag/db/document.py
index 528a399..54ac451 100644
--- a/rag/db/document.py
+++ b/rag/db/document.py
@@ -1,6 +1,7 @@
import hashlib
import os
+from langchain_community.document_loaders.blob_loaders import Blob
import psycopg
from loguru import logger as log
@@ -26,11 +27,11 @@ class DocumentDB:
cur.execute(TABLES)
self.conn.commit()
- def __hash(self, blob: bytes) -> str:
+ def __hash(self, blob: Blob) -> str:
log.debug("Hashing document...")
- return hashlib.sha256(blob).hexdigest()
+ return hashlib.sha256(blob.as_bytes()).hexdigest()
- def add(self, blob: bytes) -> bool:
+ def add(self, blob: Blob) -> bool:
with self.conn.cursor() as cur:
hash = self.__hash(blob)
cur.execute(
diff --git a/rag/db/vector.py b/rag/db/vector.py
index 4aa62cc..bbbbf32 100644
--- a/rag/db/vector.py
+++ b/rag/db/vector.py
@@ -50,6 +50,9 @@ class VectorDB:
def search(self, query: List[float], limit: int = 4) -> List[ScoredPoint]:
log.debug("Searching for vectors...")
hits = self.client.search(
- collection_name=self.collection_name, query_vector=query, limit=limit
+ collection_name=self.collection_name,
+ query_vector=query,
+ limit=limit,
+ score_threshold=0.6,
)
return hits
diff --git a/rag/llm/encoder.py b/rag/llm/encoder.py
index 7c03dc5..95f3c6a 100644
--- a/rag/llm/encoder.py
+++ b/rag/llm/encoder.py
@@ -1,5 +1,5 @@
import os
-from typing import List
+from typing import Iterator, List
from uuid import uuid4
import ollama
@@ -7,8 +7,11 @@ from langchain_core.documents import Document
from loguru import logger as log
from qdrant_client.http.models import StrictFloat
-from rag.db.vector import Point
+try:
+ from rag.db.vector import Point
+except ModuleNotFoundError:
+ from db.vector import Point
class Encoder:
def __init__(self) -> None:
@@ -18,11 +21,11 @@ class Encoder:
def __encode(self, prompt: str) -> List[StrictFloat]:
return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"])
- def encode_document(self, chunks: List[Document]) -> List[Point]:
+ def encode_document(self, chunks: Iterator[Document]) -> List[Point]:
log.debug("Encoding document...")
return [
Point(
- id=str(uuid4()),
+ id=uuid4().hex,
vector=self.__encode(chunk.page_content),
payload={"text": chunk.page_content},
)
diff --git a/rag/llm/generator.py b/rag/llm/generator.py
index b0c6c40..8c7702f 100644
--- a/rag/llm/generator.py
+++ b/rag/llm/generator.py
@@ -15,9 +15,8 @@ class Generator:
def __init__(self) -> None:
self.model = os.environ["GENERATOR_MODEL"]
- def __metaprompt(self, role: str, prompt: Prompt) -> str:
+ def __metaprompt(self, prompt: Prompt) -> str:
metaprompt = (
- f"You are a {role}.\n"
"Answer the following question using the provided context.\n"
"If you can't find the answer, do not pretend you know it,"
'but answer "I don\'t know".\n\n'
@@ -28,7 +27,7 @@ class Generator:
)
return metaprompt
- def generate(self, prompt: Prompt, role: str) -> str:
+ def generate(self, prompt: Prompt) -> str:
log.debug("Generating answer...")
- metaprompt = self.__metaprompt(role, prompt)
+ metaprompt = self.__metaprompt(prompt)
return ollama.generate(model=self.model, prompt=metaprompt)
diff --git a/rag/main.py b/rag/main.py
deleted file mode 100644
index e69de29..0000000
--- a/rag/main.py
+++ /dev/null
diff --git a/rag/parser/pdf.py b/rag/parser/pdf.py
index ed4dc8b..cbd86a3 100644
--- a/rag/parser/pdf.py
+++ b/rag/parser/pdf.py
@@ -1,28 +1,24 @@
import os
from pathlib import Path
-from typing import Iterator, Optional
+from typing import Iterator
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langchain_community.document_loaders.parsers.pdf import (
PyPDFParser,
)
-from rag.db.document import DocumentDB
+from langchain_community.document_loaders.blob_loaders import Blob
-class PDF:
+class PDFParser:
def __init__(self) -> None:
- self.db = DocumentDB()
self.parser = PyPDFParser(password=None, extract_images=False)
- def from_data(self, blob) -> Optional[Iterator[Document]]:
- if self.db.add(blob):
- yield from self.parser.parse(blob)
- yield None
+ def from_data(self, blob: Blob) -> Iterator[Document]:
+ yield from self.parser.parse(blob)
- def from_path(self, file_path: Path) -> Optional[Iterator[Document]]:
- blob = Blob.from_path(file_path)
- from_data(blob)
+ def from_path(self, path: Path) -> Iterator[Document]:
+ return Blob.from_path(path)
def chunk(self, content: Iterator[Document]):
splitter = RecursiveCharacterTextSplitter(
diff --git a/rag/rag.py b/rag/rag.py
index 488e30a..cd4537e 100644
--- a/rag/rag.py
+++ b/rag/rag.py
@@ -1,3 +1,5 @@
+from dataclasses import dataclass
+from io import BytesIO
from pathlib import Path
from typing import List
@@ -5,27 +7,46 @@ from dotenv import load_dotenv
from loguru import logger as log
from qdrant_client.models import StrictFloat
-from rag.db.document import DocumentDB
-from rag.db.vector import VectorDB
-from rag.llm.encoder import Encoder
-from rag.llm.generator import Generator, Prompt
-from rag.parser import pdf
+
+try:
+ from rag.db.vector import VectorDB
+ from rag.db.document import DocumentDB
+ from rag.llm.encoder import Encoder
+ from rag.llm.generator import Generator, Prompt
+ from rag.parser.pdf import PDFParser
+except ModuleNotFoundError:
+ from db.vector import VectorDB
+ from db.document import DocumentDB
+ from llm.encoder import Encoder
+ from llm.generator import Generator, Prompt
+ from parser.pdf import PDFParser
+
+
+@dataclass
+class Response:
+ query: str
+ context: List[str]
+ answer: str
class RAG:
def __init__(self) -> None:
# FIXME: load this somewhere else?
load_dotenv()
+ self.pdf_parser = PDFParser()
self.generator = Generator()
self.encoder = Encoder()
self.vector_db = VectorDB()
+ self.doc_db = DocumentDB()
+
+ def add_pdf_from_path(self, path: Path):
+ blob = self.pdf_parser.from_path(path)
+ self.add_pdf_from_blob(blob)
- # FIXME: refactor this, add vector?
- def add_pdf(self, filepath: Path):
- chunks = pdf.parser(filepath)
- added = self.document_db.add(chunks)
- if added:
- log.debug(f"Adding pdf with filepath: {filepath} to vector db")
+ def add_pdf_from_blob(self, blob: BytesIO):
+ if self.doc_db.add(blob):
+ log.debug("Adding pdf to vector database...")
+ chunks = self.pdf_parser.from_data(blob)
points = self.encoder.encode_document(chunks)
self.vector_db.add(points)
else:
@@ -34,10 +55,11 @@ class RAG:
def __context(self, query_emb: List[StrictFloat], limit: int) -> str:
hits = self.vector_db.search(query_emb, limit)
log.debug(f"Got {len(hits)} hits in the vector db with limit={limit}")
- return "\n".join(h.payload["text"] for h in hits)
+ return [h.payload["text"] for h in hits]
- def rag(self, query: str, role: str, limit: int = 5) -> str:
+ def retrive(self, query: str, limit: int = 5) -> Response:
query_emb = self.encoder.encode_query(query)
context = self.__context(query_emb, limit)
- prompt = Prompt(query, context)
- return self.generator.generate(prompt, role)["response"]
+ prompt = Prompt(query, "\n".join(context))
+ answer = self.generator.generate(prompt)["response"]
+ return Response(query, context, answer)
diff --git a/rag/ui.py b/rag/ui.py
index 1e4dd64..277c084 100644
--- a/rag/ui.py
+++ b/rag/ui.py
@@ -1,8 +1,14 @@
import streamlit as st
+from langchain_community.document_loaders.blob_loaders import Blob
+
+try:
+ from rag.rag import RAG
+except ModuleNotFoundError:
+ from rag import RAG
+
+rag = RAG()
-# from loguru import logger as log
-# from rag.rag import RAG
def upload_pdfs():
files = st.file_uploader(
@@ -11,10 +17,35 @@ def upload_pdfs():
accept_multiple_files=True,
)
for file in files:
- bytes = file.read()
- st.write(bytes)
+ blob = Blob.from_data(file.read())
+ rag.add_pdf_from_blob(blob)
if __name__ == "__main__":
st.header("RAG-UI")
+
upload_pdfs()
+ query = st.text_area(
+ "query",
+ key="query",
+ height=100,
+ placeholder="Enter query here",
+ help="",
+ label_visibility="collapsed",
+ disabled=False,
+ )
+
+ (result_column, context_column) = st.columns(2)
+
+ if query:
+ response = rag.retrive(query)
+
+ with result_column:
+ st.markdown("### Answer")
+ st.markdown(response.answer)
+
+ with context_column:
+ st.markdown("### Context")
+ for c in response.context:
+ st.markdown(c)
+ st.markdown("---")