From 53b256b0a49ff11b42c04f0a11272899571b389d Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 21 Mar 2022 23:55:53 +0100 Subject: feat(constrained optimization): basic tutorial --- jax_optimization_tutorial/__init__.py | 1 + jax_optimization_tutorial/basic-optimization.ipynb | 981 +++++++++++++++++++++ 2 files changed, 982 insertions(+) create mode 100644 jax_optimization_tutorial/__init__.py create mode 100644 jax_optimization_tutorial/basic-optimization.ipynb (limited to 'jax_optimization_tutorial') diff --git a/jax_optimization_tutorial/__init__.py b/jax_optimization_tutorial/__init__.py new file mode 100644 index 0000000..b794fd4 --- /dev/null +++ b/jax_optimization_tutorial/__init__.py @@ -0,0 +1 @@ +__version__ = '0.1.0' diff --git a/jax_optimization_tutorial/basic-optimization.ipynb b/jax_optimization_tutorial/basic-optimization.ipynb new file mode 100644 index 0000000..eb6a339 --- /dev/null +++ b/jax_optimization_tutorial/basic-optimization.ipynb @@ -0,0 +1,981 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "f48669df", + "metadata": {}, + "source": [ + "# Solving Optimization Problems with JAX\n", + "Source: https://medium.com/swlh/solving-optimization-problems-with-jax-98376508bd4f" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "02e5a88f", + "metadata": {}, + "outputs": [], + "source": [ + "import jax.numpy as np\n", + "from jax import grad, jit, vmap\n", + "from jax import random\n", + "from jax import jacfwd, jacrev\n", + "from jax.numpy import linalg\n", + "from numpy import nanargmin, nanargmax" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "afe9f53c", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:absl:No GPU/TPU found, falling back to CPU. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" + ] + } + ], + "source": [ + "key = random.PRNGKey(42)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "b8bc9bbe", + "metadata": {}, + "outputs": [], + "source": [ + "def f(x): return 3 * x[0] ** 2" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "aebb2724", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DeviceArray([12.], dtype=float32)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gradf = grad(f)\n", + "gradf(np.array([2.0]))" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "9d891860", + "metadata": {}, + "outputs": [], + "source": [ + "def circle(x): return x[0] ** 2 + x[1] ** 2" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "6a70bd20", + "metadata": {}, + "outputs": [], + "source": [ + "J = jacfwd(circle)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "dff71dc6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DeviceArray([2., 4.], dtype=float32)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "J(np.array([1.0, 2.0]))" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "b90095a4", + "metadata": {}, + "outputs": [], + "source": [ + "def hessian(f): return jacfwd(jacrev(f))" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "eb4124f0", + "metadata": {}, + "outputs": [], + "source": [ + "H = hessian(circle)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "a6d57fb1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DeviceArray([[2., 0.],\n", + " [0., 2.]], dtype=float32)" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "H(np.array([1.0, 2.0]))" + ] + }, + { + "cell_type": "code", + 
"execution_count": 17, + "id": "b50fe2a5", + "metadata": {}, + "outputs": [], + "source": [ + "# Paper problem" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "89a1d95d", + "metadata": {}, + "outputs": [], + "source": [ + "def y(x): return ((x * np.sqrt(12 * x - 36)) / (2 * (x - 3)))\n", + "def L(x): return np.sqrt(x ** 2 + y(x) ** 2)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "402c0f64", + "metadata": {}, + "outputs": [], + "source": [ + "gradL = grad(L)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "5bf3e5c2", + "metadata": {}, + "outputs": [], + "source": [ + "def minGD(x): return x - 0.01 * gradL(x)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "e340a396", + "metadata": {}, + "outputs": [], + "source": [ + "domain = np.linspace(3.0, 5.0, num=50)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "1d054e99", + "metadata": {}, + "outputs": [], + "source": [ + "vfunGD = vmap(minGD)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "33a48ebf", + "metadata": {}, + "outputs": [], + "source": [ + "for epoch in range(50):\n", + " domain = vfunGD(domain)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "d698e29a", + "metadata": {}, + "outputs": [], + "source": [ + "minfun = vmap(L)\n", + "minimums = minfun(domain)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "eed252b6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DeviceArray([ nan, 8.337169 , 7.8126216, 7.8554583, 7.8657427,\n", + " 7.8686085, 7.869442 , 7.869495 , 7.8691173, 7.8684244,\n", + " 7.867449 , 7.8662043, 7.8646903, 7.8629084, 7.860858 ,\n", + " 7.8585405, 7.8559637, 7.853135 , 7.8500714, 7.8467846,\n", + " 7.843305 , 7.8396564, 7.8358707, 7.8319883, 7.828048 ,\n", + " 7.8240952, 7.8201776, 7.8163476, 7.812655 , 7.809155 ,\n", + " 7.805899 , 7.8029428, 7.800338 , 7.7981334, 7.7963796,\n", + " 7.795119 , 
7.7943954, 7.7942476, 7.7947106, 7.795813 ,\n", + " 7.797586 , 7.8000493, 7.8032246, 7.807126 , 7.811767 ,\n", + " 7.8171563, 7.8233 , 7.830198 , 7.837856 , 7.846267 ], dtype=float32)" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "minimums" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "92a900e8", + "metadata": {}, + "outputs": [], + "source": [ + "arglist = nanargmin(minimums)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "a520c53f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "37" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "arglist" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "e5da0824", + "metadata": {}, + "outputs": [], + "source": [ + "argmin = domain[arglist]" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "d2d53cec", + "metadata": {}, + "outputs": [], + "source": [ + "minimum = minimums[arglist]" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "930267ad", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "the minimum is 7.794247627258301 the arg min is 4.505731105804443\n" + ] + } + ], + "source": [ + "print(f\"the minimum is {minimum} the arg min is {argmin}\")" + ] + }, + { + "cell_type": "markdown", + "id": "c7910312", + "metadata": {}, + "source": [ + "# Newton's method\n", + "\n", + "Solve for where $f(x) = 0$\n", + "\n", + "$$\n", + " x_{n+1} = x_n - \\frac{f(x_n)}{f^\\prime(x_n)}\n", + "$$\n", + "\n", + "But we search for $f^\\prime(x) = 0$\n", + "\n", + "$$\n", + " x_{n+1} = x_n - \\frac{f^{\\prime}(x_n)}{f^{\\prime\\prime}(x_n)}\n", + " $$" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "426a3238", + "metadata": {}, + "outputs": [], + "source": [ + "gradL = grad(L)" + ] + }, + { + "cell_type": "code", + 
"execution_count": 36, + "id": "21cd42fa", + "metadata": {}, + "outputs": [], + "source": [ + "gradL2 = grad(gradL)" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "4c5b3265", + "metadata": {}, + "outputs": [], + "source": [ + "def min_newton(x): return x - gradL(x) / gradL2(x)" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "4dd342ac", + "metadata": {}, + "outputs": [], + "source": [ + "domain = np.linspace(3.0, 5.0, num=50)" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "235d15b4", + "metadata": {}, + "outputs": [], + "source": [ + "f_newton = vmap(min_newton)" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "9a88cd8b", + "metadata": {}, + "outputs": [], + "source": [ + "for epoch in range(50):\n", + " domain = f_newton(domain)" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "529dbb0e", + "metadata": {}, + "outputs": [], + "source": [ + "mins = minfun(domain)" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "65aa7311", + "metadata": {}, + "outputs": [], + "source": [ + "arglist = nanargmin(minimums)\n", + "argmin = domain[arglist]\n", + "minimum = minimums[arglist]" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "c0327e0e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "the minimum is 7.794247627258301 the arg min is 4.499999523162842\n" + ] + } + ], + "source": [ + "print(f\"the minimum is {minimum} the arg min is {argmin}\")" + ] + }, + { + "cell_type": "markdown", + "id": "1180ee4c", + "metadata": {}, + "source": [ + "# Multivariable Optimization\n", + "\n", + "$$\n", + " X_{n+1} = X_n - 0.01 \\nabla f(X_n)\n", + "$$" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "2a968948", + "metadata": {}, + "outputs": [], + "source": [ + "def paraboloid(x): return (x[0] * x[1] - 2) ** 2 + (x[1] - 3) ** 2" + ] + }, + { + "cell_type": "code", + 
"execution_count": 49, + "id": "f893af70", + "metadata": {}, + "outputs": [], + "source": [ + "f_min = vmap(paraboloid)" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "id": "0b035649", + "metadata": {}, + "outputs": [], + "source": [ + "J = jacfwd(paraboloid)" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "edd62878", + "metadata": {}, + "outputs": [], + "source": [ + "def jacobian_min(x): return x - 0.01 * J(x)" + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "id": "81c302fd", + "metadata": {}, + "outputs": [], + "source": [ + "domain = random.uniform(key, shape=(50, 2), dtype=\"float32\", minval=-5.0, maxval=5.0)" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "id": "70d0555a", + "metadata": {}, + "outputs": [], + "source": [ + "v_jacobian_min = vmap(jacobian_min)" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "id": "d53b2818", + "metadata": {}, + "outputs": [], + "source": [ + "for epoch in range(150):\n", + " domain = v_jacobian_min(domain)" + ] + }, + { + "cell_type": "code", + "execution_count": 72, + "id": "93c3cfdf", + "metadata": {}, + "outputs": [], + "source": [ + "mins = f_min(domain)" + ] + }, + { + "cell_type": "code", + "execution_count": 73, + "id": "c2f9afa2", + "metadata": {}, + "outputs": [], + "source": [ + "arglist = nanargmin(mins)\n", + "argmin = domain[arglist]\n", + "minimum = minimums[arglist]" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "id": "91ebaa54", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "the minimum is 7.794247627258301 the arg min is [-0.54916656 0.60931224]\n" + ] + } + ], + "source": [ + "print(f\"the minimum is {minimum} the arg min is {argmin}\")" + ] + }, + { + "cell_type": "markdown", + "id": "e5e16d79", + "metadata": {}, + "source": [ + "# The Hessian\n", + "\n", + "Similar to the Newton method:\n", + "$$\n", + " X_{n+1} = X_n - H^{-1}(X_n) \\nabla f(X_n)\n", + 
"$$" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "id": "0c11e322", + "metadata": {}, + "outputs": [], + "source": [ + "def hessian(f): return jacfwd(jacrev(f))" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "id": "7a5d3955", + "metadata": {}, + "outputs": [], + "source": [ + "H = hessian(paraboloid)" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "id": "8b32404d", + "metadata": {}, + "outputs": [], + "source": [ + "def hessian_min(x): return x - 0.01 * linalg.inv(H(x)) @ J(x)" + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "id": "050f46e5", + "metadata": {}, + "outputs": [], + "source": [ + "domain = random.uniform(key, shape=(50, 2), dtype=\"float32\", minval=-5.0, maxval=5.0)" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "id": "cba5bd2b", + "metadata": {}, + "outputs": [], + "source": [ + "v_hessian_min = vmap(hessian_min)" + ] + }, + { + "cell_type": "code", + "execution_count": 77, + "id": "27b219df", + "metadata": {}, + "outputs": [], + "source": [ + "for epoch in range(150):\n", + " domain = v_hessian_min(domain)" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "id": "dfdf027b", + "metadata": {}, + "outputs": [], + "source": [ + "mins = f_min(domain)" + ] + }, + { + "cell_type": "code", + "execution_count": 79, + "id": "db074a43", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "the minimum is 7.794247627258301 the arg min is [-0.74005073 -0.00774691]\n" + ] + } + ], + "source": [ + "arglist = nanargmin(mins)\n", + "argmin = domain[arglist]\n", + "minimum = minimums[arglist]\n", + "print(f\"the minimum is {minimum} the arg min is {argmin}\")" + ] + }, + { + "cell_type": "markdown", + "id": "e2cc6b7d", + "metadata": {}, + "source": [ + "# Multivariable Constrained Optimization\n", + "\n", + "Lagrangian multipliers:\n", + "\n", + "$$\n", + "\\begin{align}\n", + "\\nabla f = \\lambda \\nabla g\\\\\n", + "\\nabla f 
- \\lambda \\nabla g = 0\\\\\n", + "\\nabla (f - \\lambda g) = 0\\\\\n", + "\\nabla (\\mathcal{L}) = 0\\\\\n", + "\\end{align}\n", + "$$\n", + "\n", + "Our iterative approach will be:\n", + "\n", + "$$\n", + "X_{n+1} = X_n - \\nabla^2\\mathcal{L}^{-1}(X_n)\\nabla\\mathcal{L}(X_n)\n", + "$$" + ] + }, + { + "cell_type": "code", + "execution_count": 80, + "id": "a8ff98bc", + "metadata": {}, + "outputs": [], + "source": [ + "def f(x): return 4 * (x[0] ** 2) * x[1]\n", + "def g(x): return x[0]**2 + x[1]**2 - 3" + ] + }, + { + "cell_type": "code", + "execution_count": 81, + "id": "ba2ce45c", + "metadata": {}, + "outputs": [], + "source": [ + "f_min = vmap(f)" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "id": "c0d03ac0", + "metadata": {}, + "outputs": [], + "source": [ + "def lagrange(l): return f(l[0:2]) - l[2]*g(l[0:2])" + ] + }, + { + "cell_type": "code", + "execution_count": 83, + "id": "fc1d5af3", + "metadata": {}, + "outputs": [], + "source": [ + "L = jacfwd(lagrange)" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "id": "60f8134b", + "metadata": {}, + "outputs": [], + "source": [ + "gradL = jacfwd(L)" + ] + }, + { + "cell_type": "code", + "execution_count": 85, + "id": "ef6ddc53", + "metadata": {}, + "outputs": [], + "source": [ + "def lagrange_min(l): return l - linalg.inv(gradL(l)) @ L(l)" + ] + }, + { + "cell_type": "code", + "execution_count": 86, + "id": "c67b6785", + "metadata": {}, + "outputs": [], + "source": [ + "domain = random.uniform(key, shape=(50,3), dtype=\"float32\", minval=-5.0, maxval=5.0)" + ] + }, + { + "cell_type": "code", + "execution_count": 87, + "id": "d61873ab", + "metadata": {}, + "outputs": [], + "source": [ + "v_lagrange_min = vmap(lagrange_min)" + ] + }, + { + "cell_type": "code", + "execution_count": 88, + "id": "73539993", + "metadata": {}, + "outputs": [], + "source": [ + "for epoch in range(150):\n", + " domain = v_lagrange_min(domain)" + ] + }, + { + "cell_type": "code", + "execution_count": 89, 
+ "id": "7f4d91c7", + "metadata": {}, + "outputs": [], + "source": [ + "mins = f_min(domain)" + ] + }, + { + "cell_type": "code", + "execution_count": 91, + "id": "38f56026", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "the minimum is 7.8554582595825195 the arg min is [-1.4142135 -1. -4. ]\n" + ] + } + ], + "source": [ + "arglist = nanargmin(mins)\n", + "argmin = domain[arglist]\n", + "minimum = minimums[arglist]\n", + "print(f\"the minimum is {minimum} the arg min is {argmin}\")" + ] + }, + { + "cell_type": "markdown", + "id": "380c75c3", + "metadata": {}, + "source": [ + "# Three Variable Multivariable Constrained Optimization\n", + "$$\n", + "\\begin{align}\n", + "f(x, y, z)& = xyz\\\\\n", + "g(x, y, z)& = 2xy + 2yz + 2xz - 64\\\\\n", + "\\end{align}\n", + "$$\n", + "\n", + "Our iterative approach will be:\n", + "\n", + "$$\n", + "X_{n+1} = X_n - 0.1 \\nabla^2\\mathcal{L}^{-1}(X_n)\\nabla\\mathcal{L}(X_n)\n", + "$$" + ] + }, + { + "cell_type": "code", + "execution_count": 98, + "id": "820312ab", + "metadata": {}, + "outputs": [], + "source": [ + "def f(x): return x[0]*x[1]*x[2]\n", + "def g(x): return 2*x[0]*x[1] + 2*x[1]*x[2] + 2*x[0]*x[2] - 64\n", + "\n", + "f_min = vmap(f)\n", + "\n", + "def lagrange(l): return f(l[0:3]) - l[3]*g(l[0:3])\n", + "\n", + "L = jacfwd(lagrange)\n", + "gradL = jacfwd(L)" + ] + }, + { + "cell_type": "code", + "execution_count": 99, + "id": "61f9db1e", + "metadata": {}, + "outputs": [], + "source": [ + "def lagrangian_min(l): return l - 0.1 * linalg.inv(gradL(l)) @ L(l)" + ] + }, + { + "cell_type": "code", + "execution_count": 100, + "id": "9e8d959f", + "metadata": {}, + "outputs": [], + "source": [ + "domain = random.uniform(key, shape=(50, 4), dtype=\"float32\", minval = 0, maxval=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 101, + "id": "c3b8c4fa", + "metadata": {}, + "outputs": [], + "source": [ + "v_lagrangian_min = vmap(lagrangian_min)" + ] + }, + { + 
"cell_type": "code", + "execution_count": 102, + "id": "8e83319e", + "metadata": {}, + "outputs": [], + "source": [ + "for epoch in range(200):\n", + " domain = v_lagrangian_min(domain)" + ] + }, + { + "cell_type": "code", + "execution_count": 103, + "id": "d5f28369", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The minimum is 34.83720016479492, the argmin is (3.2659873962402344,3.2659854888916016,3.2659873962402344), the lagrangian is 0.8164968490600586\n" + ] + } + ], + "source": [ + "maximums = f_min(domain)\n", + "\n", + "arglist = nanargmax(maximums)\n", + "argmin = domain[arglist]\n", + "minimum = maximums[arglist]\n", + "\n", + "print(\"The minimum is {}, the argmin is ({},{},{}), the lagrangian is {}\".format(minimum,\n", + " argmin[0], argmin[1], argmin[2], argmin[3]))" + ] + }, + { + "cell_type": "markdown", + "id": "0bf20ea5", + "metadata": {}, + "source": [ + "# Multivariable MultiConstrained Optimization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e834f8ba", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} -- cgit v1.2.3-70-g09d2