summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-03-22 22:21:19 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-03-22 22:21:19 +0100
commita869a76e9947a0e975de0f06cd819fd82454472a (patch)
tree14a3ed049df929677bb97ada2188be0de8d2b51b
parent5082b970846484b69c31268855326ac8df370d29 (diff)
feat: add start of linear regression in jax
-rw-r--r--jax_optimization_tutorial/linear-regression.ipynb460
1 files changed, 460 insertions, 0 deletions
diff --git a/jax_optimization_tutorial/linear-regression.ipynb b/jax_optimization_tutorial/linear-regression.ipynb
new file mode 100644
index 0000000..403c55d
--- /dev/null
+++ b/jax_optimization_tutorial/linear-regression.ipynb
@@ -0,0 +1,460 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "3387e989",
+ "metadata": {},
+ "source": [
+ "# Linear Regression with JAX\n",
+ "Source: https://danielrothenberg.com/blog/2020/Sep/jax-first-steps-pt1/"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "36893279",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The autoreload extension is already loaded. To reload it, use:\n",
+ " %reload_ext autoreload\n"
+ ]
+ }
+ ],
+ "source": [
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "\n",
+ "%matplotlib inline\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "import numpy as np\n",
+ "import jax.numpy as jnp"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "565bff63",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from jax import random\n",
+ "\n",
+ "def make_key():\n",
+ " \"\"\" Helper function to generate a key for jax's parallel PRNG \n",
+ " using standard numpy random functions. \n",
+ "\n",
+ " \"\"\"\n",
+ " seed = np.random.randint(2**16 - 1)\n",
+ " return random.PRNGKey(seed)\n",
+ "\n",
+ "n = 100\n",
+ "rands = random.uniform(make_key(), shape=(n, ), minval=-1, maxval=1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "11fa0b53",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[[1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]\n",
+ " [1. 1.]]\n",
+ "[[-0.49573565 1. ]\n",
+ " [-0.75743556 1. ]\n",
+ " [-0.10826278 1. ]\n",
+ " [-0.06045723 1. ]\n",
+ " [-0.6449168 1. ]\n",
+ " [-0.73400044 1. ]\n",
+ " [-0.37836146 1. ]\n",
+ " [-0.01198149 1. ]\n",
+ " [ 0.61107445 1. ]\n",
+ " [ 0.29680896 1. ]\n",
+ " [ 0.9514713 1. ]\n",
+ " [-0.94317484 1. ]\n",
+ " [ 0.5943241 1. ]\n",
+ " [-0.00902605 1. ]\n",
+ " [ 0.28311515 1. ]\n",
+ " [-0.1576097 1. ]\n",
+ " [ 0.50905204 1. ]\n",
+ " [ 0.40418315 1. ]\n",
+ " [-0.13876987 1. ]\n",
+ " [-0.10519576 1. ]\n",
+ " [ 0.4820497 1. ]\n",
+ " [ 0.16001129 1. ]\n",
+ " [-0.78287053 1. ]\n",
+ " [-0.83056164 1. ]\n",
+ " [-0.45683146 1. ]\n",
+ " [ 0.65588903 1. ]\n",
+ " [-0.30958128 1. ]\n",
+ " [-0.57034206 1. ]\n",
+ " [ 0.11413717 1. ]\n",
+ " [ 0.30006003 1. ]\n",
+ " [ 0.3948493 1. ]\n",
+ " [ 0.93246984 1. ]\n",
+ " [ 0.08236265 1. ]\n",
+ " [ 0.5821161 1. ]\n",
+ " [ 0.01212478 1. ]\n",
+ " [ 0.9328327 1. ]\n",
+ " [ 0.09593153 1. ]\n",
+ " [ 0.71609616 1. ]\n",
+ " [ 0.5419252 1. ]\n",
+ " [-0.78604007 1. ]\n",
+ " [ 0.9995785 1. ]\n",
+ " [-0.48301148 1. ]\n",
+ " [ 0.93194747 1. ]\n",
+ " [ 0.6257863 1. ]\n",
+ " [-0.42878604 1. ]\n",
+ " [-0.11879826 1. ]\n",
+ " [-0.60145855 1. ]\n",
+ " [ 0.25006437 1. ]\n",
+ " [ 0.4595716 1. ]\n",
+ " [ 0.48950648 1. ]\n",
+ " [ 0.77740407 1. ]\n",
+ " [ 0.17917514 1. ]\n",
+ " [ 0.52525926 1. ]\n",
+ " [ 0.7628691 1. ]\n",
+ " [-0.6807344 1. ]\n",
+ " [ 0.14769888 1. ]\n",
+ " [ 0.8373101 1. ]\n",
+ " [-0.90767694 1. ]\n",
+ " [ 0.90058756 1. ]\n",
+ " [-0.4039154 1. ]\n",
+ " [-0.9033947 1. ]\n",
+ " [ 0.8437958 1. ]\n",
+ " [-0.75383854 1. ]\n",
+ " [-0.7118697 1. ]\n",
+ " [ 0.2737422 1. ]\n",
+ " [-0.9684801 1. ]\n",
+ " [ 0.20652676 1. ]\n",
+ " [ 0.01346517 1. ]\n",
+ " [-0.9524138 1. ]\n",
+ " [-0.08028603 1. ]\n",
+ " [-0.71781445 1. ]\n",
+ " [ 0.08346677 1. ]\n",
+ " [ 0.6743789 1. ]\n",
+ " [ 0.01168156 1. ]\n",
+ " [ 0.0992105 1. ]\n",
+ " [-0.10572052 1. ]\n",
+ " [-0.07284904 1. ]\n",
+ " [-0.66415024 1. ]\n",
+ " [ 0.37942672 1. ]\n",
+ " [ 0.16580987 1. ]\n",
+ " [ 0.15668988 1. ]\n",
+ " [-0.4482379 1. ]\n",
+ " [-0.9518373 1. ]\n",
+ " [-0.22508144 1. ]\n",
+ " [-0.87991786 1. ]\n",
+ " [ 0.02797484 1. ]\n",
+ " [ 0.5076406 1. ]\n",
+ " [-0.40049028 1. ]\n",
+ " [ 0.55776215 1. ]\n",
+ " [-0.44021344 1. ]\n",
+ " [ 0.17129445 1. ]\n",
+ " [-0.4132347 1. ]\n",
+ " [ 0.4849968 1. ]\n",
+ " [ 0.7109468 1. ]\n",
+ " [ 0.66096044 1. ]\n",
+ " [-0.56528807 1. ]\n",
+ " [-0.55501866 1. ]\n",
+ " [-0.38958454 1. ]\n",
+ " [-0.17273688 1. ]\n",
+ " [-0.00741458 1. ]]\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Unlike NumPy arrays, JAX arrays are immutable, so we update them functionally: x.at[idx].set(value) returns a new array\n",
+ "\n",
+ "x = jnp.ones((n, 2))\n",
+ "print(x)\n",
+ "x = x.at[:, 0].set(rands)\n",
+ "print(x)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a9a1e2fb",
+ "metadata": {},
+ "source": [
+ "# Simple first model\n",
+ "\n",
+ "$$\n",
+ " y = X\\beta + \\epsilon\n",
+ "$$\n",
+ "\n",
+ "where $\\beta = [\\beta_0, \\beta_1]$: $\\beta_0$ is the slope relating x and y, $\\beta_1$ is an offset (bias), and $\\epsilon$ is uncorrelated noise drawn from a standard normal distribution, $\\mathcal{N}(0, 1)$."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "7bf67529",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Fix our true slope and bias\n",
+ "slope, bias = 3.0, 2.0\n",
+ "beta_true = jnp.array((slope, bias))\n",
+ "eps = random.normal(make_key(), shape=(n,))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "5a6d3ce8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "y = (x @ beta_true) + eps"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "462c84ec",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "<matplotlib.collections.PathCollection at 0x7f8b306678b0>"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ "<Figure size 432x288 with 1 Axes>"
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "plt.scatter(x[:,0], y)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "547a51fd",
+ "metadata": {},
+ "source": [
+ "# Analytical Model Fitting\n",
+ "\n",
+ "\n",
+ "If the problem is simple enough, we can find a closed-form analytical solution. We seek the $\\beta$ that best predicts y from x. If we use the squared-error loss $\\mathcal{L}(\\beta)=||\\mathbf{X}\\beta-\\mathbf{Y}||^2$, its gradient with respect to $\\beta$ is:\n",
+ "\n",
+ "$$\n",
+ " \\frac{\\partial \\mathcal{L}(\\beta)}{\\partial \\beta} = -2 \\mathbf{Y}^T\\mathbf{X}+2\\beta^T\\mathbf{X}^T\\mathbf{X}\n",
+ "$$\n",
+ "\n",
+ "Setting this gradient to zero and solving for $\\beta$ gives:\n",
+ "\n",
+ "$$\n",
+ "\\beta = (\\mathbf{X}^T\\mathbf{X})^{-1}\\mathbf{X}^T\\mathbf{Y}\n",
+ "$$"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "61ad3658",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[2.9267178 1.855265 ]\n"
+ ]
+ }
+ ],
+ "source": [
+ "beta_ols = jnp.linalg.inv(x.T @ x) @ (x.T @ y)\n",
+ "print(beta_ols)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "2430ba06",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ "<Figure size 432x288 with 1 Axes>"
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "plt.scatter(x[:, 0], y)\n",
+ "\n",
+ "xs = np.linspace(-1, 1)\n",
+ "xs = np.stack([xs, np.ones_like(xs)], axis=1)\n",
+ "ys_true = xs@beta_true\n",
+ "ys_fit = xs@beta_ols\n",
+ "plt.plot(xs[:, 0], ys_true, color='k', lw=2, label=\"True\")\n",
+ "plt.plot(xs[:, 0], ys_fit, color='orange', lw=1, label='OLS fit')\n",
+ "plt.legend();"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "380ba33a",
+ "metadata": {},
+ "source": [
+ "# TBC..."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "cbd1bd0d",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}