{ "cells": [
 { "cell_type": "markdown", "id": "f48669df", "metadata": {}, "source": [ "# Solving Optimization Problems with JAX\n", "Source: https://medium.com/swlh/solving-optimization-problems-with-jax-98376508bd4f" ] },
 { "cell_type": "code", "execution_count": 7, "id": "02e5a88f", "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "from jax import grad, jit, vmap\n", "from jax import random\n", "from jax import jacfwd, jacrev\n", "from jax.numpy import linalg\n", "from numpy import nanargmin, nanargmax" ] },
 { "cell_type": "code", "execution_count": 4, "id": "9bad5173", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "gpu\n" ] } ], "source": [ "# Report the default backend through the public API rather than the\n", "# deprecated internal `jax.lib.xla_bridge` module.\n", "print(jax.default_backend())" ] },
 { "cell_type": "code", "execution_count": 2, "id": "afe9f53c", "metadata": {}, "outputs": [], "source": [ "key = random.PRNGKey(42)" ] },
 { "cell_type": "code", "execution_count": 3, "id": "b8bc9bbe", "metadata": {}, "outputs": [], "source": [ "# f(x) = 3 * x[0]**2; x is indexed, so it must be array-like.\n", "def f(x): return 3 * x[0] ** 2" ] },
 { "cell_type": "code", "execution_count": 4, "id": "aebb2724", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray([12.], dtype=float32)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gradf = grad(f)\n", "gradf(jnp.array([2.0]))" ] },
 { "cell_type": "code", "execution_count": 5, "id": "9d891860", "metadata": {}, "outputs": [], "source": [ "# Squared Euclidean norm of a 2-vector.\n", "def circle(x): return x[0] ** 2 + x[1] ** 2" ] },
 { "cell_type": "code", "execution_count": 6, "id": "6a70bd20", "metadata": {}, "outputs": [], "source": [ "# For a scalar-valued function the Jacobian is the gradient.\n", "J = jacfwd(circle)" ] },
 { "cell_type": "code", "execution_count": 7, "id": "dff71dc6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray([2., 4.], dtype=float32)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "J(jnp.array([1.0, 2.0]))" ] },
 { "cell_type": "code", "execution_count": 14, "id": "b90095a4", "metadata": {}, "outputs": [], "source": [ "# Hessian built as the forward-mode Jacobian of the reverse-mode Jacobian.\n", "def hessian(f): return jacfwd(jacrev(f))" ] },
 { "cell_type": "code", "execution_count": 15, "id": "eb4124f0", "metadata": {}, "outputs": [], "source": [ "H = hessian(circle)" ] },
 { "cell_type": "code", "execution_count": 16, "id": "a6d57fb1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray([[2., 0.],\n", " [0., 2.]], dtype=float32)" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "H(jnp.array([1.0, 2.0]))" ] },
 { "cell_type": "markdown", "id": "b50fe2a5", "metadata": {}, "source": [ "# Paper problem" ] },
 { "cell_type": "code", "execution_count": 18, "id": "89a1d95d", "metadata": {}, "outputs": [], "source": [ "# y(x) evaluates to 0/0 = NaN at x = 3, so L(3) is NaN as well;\n", "# the NaN-aware nanargmin used on the results skips such entries.\n", "def y(x): return ((x * jnp.sqrt(12 * x - 36)) / (2 * (x - 3)))\n", "def L(x): return jnp.sqrt(x ** 2 + y(x) ** 2)" ] },
 { "cell_type": "code", "execution_count": 19, "id": "402c0f64", "metadata": {}, "outputs": [], "source": [ "gradL = grad(L)" ] },
 { "cell_type": "code", "execution_count": 20, "id": "5bf3e5c2", "metadata": {}, "outputs": [], "source": [ "# One gradient-descent step with a fixed step size of 0.01.\n", "def minGD(x): return x - 0.01 * gradL(x)" ] },
 { "cell_type": "code", "execution_count": 21, "id": "e340a396", "metadata": {}, "outputs": [], "source": [ "# 50 starting points; the left end point x = 3 is where y(x) is undefined.\n", "domain = jnp.linspace(3.0, 5.0, num=50)" ] },
 { "cell_type": "code", "execution_count": 22, "id": "1d054e99", "metadata": {}, "outputs": [], "source": [ "# Vectorise the scalar update so all starting points advance together.\n", "vfunGD = vmap(minGD)" ] },
 { "cell_type": "code", "execution_count": 25, "id": "33a48ebf", "metadata": {}, "outputs": [], "source": [ "# NOTE: updates `domain` in place, so re-running this cell applies 50 more steps.\n", "for epoch in range(50):\n", "    domain = vfunGD(domain)" ] },
 { "cell_type": "code", "execution_count": 28, "id": "d698e29a", "metadata": {}, "outputs": [], "source": [ "minfun = vmap(L)\n", "minimums = minfun(domain)" ] },
 { "cell_type": "code", "execution_count": 29, "id": "eed252b6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray([ nan, 8.337169 , 7.8126216, 7.8554583, 7.8657427,\n", " 
7.8686085, 7.869442 , 7.869495 , 7.8691173, 7.8684244,\n", " 7.867449 , 7.8662043, 7.8646903, 7.8629084, 7.860858 ,\n", " 7.8585405, 7.8559637, 7.853135 , 7.8500714, 7.8467846,\n", " 7.843305 , 7.8396564, 7.8358707, 7.8319883, 7.828048 ,\n", " 7.8240952, 7.8201776, 7.8163476, 7.812655 , 7.809155 ,\n", " 7.805899 , 7.8029428, 7.800338 , 7.7981334, 7.7963796,\n", " 7.795119 , 7.7943954, 7.7942476, 7.7947106, 7.795813 ,\n", " 7.797586 , 7.8000493, 7.8032246, 7.807126 , 7.811767 ,\n", " 7.8171563, 7.8233 , 7.830198 , 7.837856 , 7.846267 ], dtype=float32)" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "minimums" ] },
 { "cell_type": "code", "execution_count": 30, "id": "92a900e8", "metadata": {}, "outputs": [], "source": [ "arglist = nanargmin(minimums)" ] },
 { "cell_type": "code", "execution_count": 31, "id": "a520c53f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "37" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "arglist" ] },
 { "cell_type": "code", "execution_count": 32, "id": "e5da0824", "metadata": {}, "outputs": [], "source": [ "argmin = domain[arglist]" ] },
 { "cell_type": "code", "execution_count": 33, "id": "d2d53cec", "metadata": {}, "outputs": [], "source": [ "minimum = minimums[arglist]" ] },
 { "cell_type": "code", "execution_count": 34, "id": "930267ad", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "the minimum is 7.794247627258301 the arg min is 4.505731105804443\n" ] } ], "source": [ "print(f\"the minimum is {minimum} the arg min is {argmin}\")" ] },
 { "cell_type": "markdown", "id": "c7910312", "metadata": {}, "source": [ "# Newton's method\n", "\n", "Solve for where $f(x) = 0$\n", "\n", "$$\n", " x_{n+1} = x_n - \\frac{f(x_n)}{f^\\prime(x_n)}\n", "$$\n", "\n", "But we search for $f^\\prime(x) = 0$\n", "\n", "$$\n", " x_{n+1} = x_n - \\frac{f^{\\prime}(x_n)}{f^{\\prime\\prime}(x_n)}\n", " $$" ] },
 { "cell_type": "code", "execution_count": 35, "id": "426a3238", "metadata": {}, "outputs": [], "source": [ "gradL = grad(L)" ] },
 { "cell_type": "code", "execution_count": 36, "id": "21cd42fa", "metadata": {}, "outputs": [], "source": [ "# Second derivative of L, obtained by differentiating the gradient again.\n", "gradL2 = grad(gradL)" ] },
 { "cell_type": "code", "execution_count": 39, "id": "4c5b3265", "metadata": {}, "outputs": [], "source": [ "# One Newton step towards a root of L'(x).\n", "def min_newton(x): return x - gradL(x) / gradL2(x)" ] },
 { "cell_type": "code", "execution_count": 38, "id": "4dd342ac", "metadata": {}, "outputs": [], "source": [ "# Restart from the same grid that gradient descent started from.\n", "domain = jnp.linspace(3.0, 5.0, num=50)" ] },
 { "cell_type": "code", "execution_count": 40, "id": "235d15b4", "metadata": {}, "outputs": [], "source": [ "f_newton = vmap(min_newton)" ] },
 { "cell_type": "code", "execution_count": 41, "id": "9a88cd8b", "metadata": {}, "outputs": [], "source": [ "for epoch in range(50):\n", "    domain = f_newton(domain)" ] },
 { "cell_type": "code", "execution_count": 43, "id": "529dbb0e", "metadata": {}, "outputs": [], "source": [ "mins = minfun(domain)" ] },
 { "cell_type": "code", "execution_count": 44, "id": "65aa7311", "metadata": {}, "outputs": [], "source": [ "# Rank the Newton results (`mins`); `minimums` still holds the\n", "# gradient-descent values computed in the previous section.\n", "arglist = nanargmin(mins)\n", "argmin = domain[arglist]\n", "minimum = mins[arglist]" ] },
 { "cell_type": "code", "execution_count": 45, "id": "c0327e0e", "metadata": {}, "outputs": [], "source": [ "print(f\"the minimum is {minimum} the arg min is {argmin}\")" ] },
 { "cell_type": "markdown", "id": "1180ee4c", "metadata": {}, "source": [ "# Multivariable Optimization\n", "\n", "$$\n", " X_{n+1} = X_n - 0.01 \\nabla f(X_n)\n", "$$" ] },
 { "cell_type": "code", "execution_count": 48, "id": "2a968948", "metadata": {}, "outputs": [], "source": [ "# Non-negative test surface; it reaches 0 at x[0] * x[1] = 2, x[1] = 3, i.e. at (2/3, 3).\n", "def paraboloid(x): return (x[0] * x[1] - 2) ** 2 + (x[1] - 3) ** 2" ] },
 { "cell_type": "code", "execution_count": 49, "id": "f893af70", "metadata": {}, "outputs": [], "source": [ "f_min = 
vmap(paraboloid)" ] },
 { "cell_type": "code", "execution_count": 50, "id": "0b035649", "metadata": {}, "outputs": [], "source": [ "# Gradient of the paraboloid (Jacobian of a scalar-valued function).\n", "J = jacfwd(paraboloid)" ] },
 { "cell_type": "code", "execution_count": 51, "id": "edd62878", "metadata": {}, "outputs": [], "source": [ "# One gradient-descent step with a fixed step size of 0.01.\n", "def jacobian_min(x): return x - 0.01 * J(x)" ] },
 { "cell_type": "code", "execution_count": 69, "id": "81c302fd", "metadata": {}, "outputs": [], "source": [ "# 50 random 2-D starting points in [-5, 5) x [-5, 5).\n", "domain = random.uniform(key, shape=(50, 2), dtype=\"float32\", minval=-5.0, maxval=5.0)" ] },
 { "cell_type": "code", "execution_count": 70, "id": "70d0555a", "metadata": {}, "outputs": [], "source": [ "v_jacobian_min = vmap(jacobian_min)" ] },
 { "cell_type": "code", "execution_count": 71, "id": "d53b2818", "metadata": {}, "outputs": [], "source": [ "for epoch in range(150):\n", "    domain = v_jacobian_min(domain)" ] },
 { "cell_type": "code", "execution_count": 72, "id": "93c3cfdf", "metadata": {}, "outputs": [], "source": [ "mins = f_min(domain)" ] },
 { "cell_type": "code", "execution_count": 73, "id": "c2f9afa2", "metadata": {}, "outputs": [], "source": [ "# Read the value from `mins` (paraboloid values at the current points);\n", "# `minimums` belongs to the one-dimensional paper problem.\n", "arglist = nanargmin(mins)\n", "argmin = domain[arglist]\n", "minimum = mins[arglist]" ] },
 { "cell_type": "code", "execution_count": 74, "id": "91ebaa54", "metadata": {}, "outputs": [], "source": [ "print(f\"the minimum is {minimum} the arg min is {argmin}\")" ] },
 { "cell_type": "markdown", "id": "e5e16d79", "metadata": {}, "source": [ "# The Hessian\n", "\n", "Similar to the Newton method, damped with a step size of 0.1:\n", "$$\n", " X_{n+1} = X_n - 0.1 H^{-1}(X_n) \\nabla f(X_n)\n", "$$" ] },
 { "cell_type": "code", "execution_count": 60, "id": "0c11e322", "metadata": {}, "outputs": [], "source": [ "def hessian(f): return jacfwd(jacrev(f))" ] },
 { "cell_type": "code", "execution_count": 61, "id": "7a5d3955", "metadata": {}, "outputs": [], "source": [ "H = hessian(paraboloid)" ] },
 { "cell_type": "code", "execution_count": 62, "id": "8b32404d", "metadata": {}, "outputs": [], "source": [ "# Damped Newton step. With the previous factor of 0.01, 150 iterations leave roughly\n", "# 0.99**150 = 22% of the initial error, so the loop below stopped short of a\n", "# stationary point; 0.1 matches the constrained solvers further down.\n", "def hessian_min(x): return x - 0.1 * linalg.inv(H(x)) @ J(x)" ] },
 { "cell_type": "code", "execution_count": 75, "id": "050f46e5", "metadata": {}, "outputs": [], "source": [ "# Same key and shape as the gradient-descent run, hence the same starting points.\n", "domain = random.uniform(key, shape=(50, 2), dtype=\"float32\", minval=-5.0, maxval=5.0)" ] },
 { "cell_type": "code", "execution_count": 76, "id": "cba5bd2b", "metadata": {}, "outputs": [], "source": [ "v_hessian_min = vmap(hessian_min)" ] },
 { "cell_type": "code", "execution_count": 77, "id": "27b219df", "metadata": {}, "outputs": [], "source": [ "for epoch in range(150):\n", "    domain = v_hessian_min(domain)" ] },
 { "cell_type": "code", "execution_count": 78, "id": "dfdf027b", "metadata": {}, "outputs": [], "source": [ "mins = f_min(domain)" ] },
 { "cell_type": "code", "execution_count": 79, "id": "db074a43", "metadata": {}, "outputs": [], "source": [ "# Read the value from `mins`, not from the unrelated `minimums` array.\n", "arglist = nanargmin(mins)\n", "argmin = domain[arglist]\n", "minimum = mins[arglist]\n", "print(f\"the minimum is {minimum} the arg min is {argmin}\")" ] },
 { "cell_type": "markdown", "id": "e2cc6b7d", "metadata": {}, "source": [ "# Multivariable Constrained Optimization\n", "\n", "Lagrangian multipliers:\n", "\n", "$$\n", "\\begin{align}\n", "\\nabla f = \\lambda \\nabla g\\\\\n", "\\nabla f - \\lambda \\nabla g = 0\\\\\n", "\\nabla (f - \\lambda g) = 0\\\\\n", "\\nabla (\\mathcal{L}) = 0\\\\\n", "\\end{align}\n", "$$\n", "\n", "Our iterative approach will be:\n", "\n", "$$\n", "X_{n+1} = X_n - \\nabla^2\\mathcal{L}^{-1}(X_n)\\nabla\\mathcal{L}(X_n)\n", "$$" ] },
 { "cell_type": "code", "execution_count": 80, "id": "a8ff98bc", "metadata": {}, "outputs": [], "source": [ "# Objective f and constraint g(x) = 0 (a circle of radius sqrt(3)).\n", "def f(x): return 4 * (x[0] ** 2) * x[1]\n", "def g(x): return x[0]**2 + x[1]**2 - 3" ] },
 { "cell_type": "code", "execution_count": 
81, "id": "ba2ce45c", "metadata": {}, "outputs": [], "source": [ "f_min = vmap(f)" ] },
 { "cell_type": "code", "execution_count": 82, "id": "c0d03ac0", "metadata": {}, "outputs": [], "source": [ "# l = (x, y, lambda): objective minus the multiplier times the constraint.\n", "def lagrange(l): return f(l[0:2]) - l[2]*g(l[0:2])" ] },
 { "cell_type": "code", "execution_count": 83, "id": "fc1d5af3", "metadata": {}, "outputs": [], "source": [ "# Gradient of the Lagrangian; this rebinds the name L used for the paper problem.\n", "L = jacfwd(lagrange)" ] },
 { "cell_type": "code", "execution_count": 84, "id": "60f8134b", "metadata": {}, "outputs": [], "source": [ "# Despite the name, this is the Hessian of the Lagrangian (Jacobian of its gradient).\n", "gradL = jacfwd(L)" ] },
 { "cell_type": "code", "execution_count": 85, "id": "ef6ddc53", "metadata": {}, "outputs": [], "source": [ "# Full (undamped) Newton step on the stationarity condition grad(Lagrangian) = 0.\n", "def lagrange_min(l): return l - linalg.inv(gradL(l)) @ L(l)" ] },
 { "cell_type": "code", "execution_count": 86, "id": "c67b6785", "metadata": {}, "outputs": [], "source": [ "# Each row is a starting guess (x, y, lambda).\n", "domain = random.uniform(key, shape=(50,3), dtype=\"float32\", minval=-5.0, maxval=5.0)" ] },
 { "cell_type": "code", "execution_count": 87, "id": "d61873ab", "metadata": {}, "outputs": [], "source": [ "v_lagrange_min = vmap(lagrange_min)" ] },
 { "cell_type": "code", "execution_count": 88, "id": "73539993", "metadata": {}, "outputs": [], "source": [ "for epoch in range(150):\n", "    domain = v_lagrange_min(domain)" ] },
 { "cell_type": "code", "execution_count": 89, "id": "7f4d91c7", "metadata": {}, "outputs": [], "source": [ "# f only reads x[0] and x[1], so the trailing multiplier column is ignored.\n", "mins = f_min(domain)" ] },
 { "cell_type": "code", "execution_count": 91, "id": "38f56026", "metadata": {}, "outputs": [], "source": [ "# Read the value from `mins` (f at the converged points), not from the\n", "# unrelated `minimums` array of the paper problem.\n", "arglist = nanargmin(mins)\n", "argmin = domain[arglist]\n", "minimum = mins[arglist]\n", "print(f\"the minimum is {minimum} the arg min is {argmin}\")" ] },
 { "cell_type": "markdown", "id": "380c75c3", "metadata": {}, "source": [ "# Three Variable Multivariable Constrained Optimization\n", "$$\n", "\\begin{align}\n", "f(x, y, z)& = xyz\\\\\n", "g(x, y, z)& = 2xy + 2yz + 2xz - 64\\\\\n", "\\end{align}\n", "$$\n", "\n", "Our iterative approach will be:\n", "\n", "$$\n", "X_{n+1} = X_n - 0.1 \\nabla^2\\mathcal{L}^{-1}(X_n)\\nabla\\mathcal{L}(X_n)\n", "$$" ] },
 { "cell_type": "code", "execution_count": 98, "id": "820312ab", "metadata": {}, "outputs": [], "source": [ "def f(x): return x[0]*x[1]*x[2]\n", "def g(x): return 2*x[0]*x[1] + 2*x[1]*x[2] + 2*x[0]*x[2] - 64\n", "\n", "f_min = vmap(f)\n", "\n", "def lagrange(l): return f(l[0:3]) - l[3]*g(l[0:3])\n", "\n", "L = jacfwd(lagrange)\n", "gradL = jacfwd(L)" ] },
 { "cell_type": "code", "execution_count": 99, "id": "61f9db1e", "metadata": {}, "outputs": [], "source": [ "def lagrangian_min(l): return l - 0.1 * linalg.inv(gradL(l)) @ L(l)" ] },
 { "cell_type": "code", "execution_count": 100, "id": "9e8d959f", "metadata": {}, "outputs": [], "source": [ "domain = random.uniform(key, shape=(50, 4), dtype=\"float32\", minval = 0, maxval=10)" ] },
 { "cell_type": "code", "execution_count": 101, "id": "c3b8c4fa", "metadata": {}, "outputs": [], "source": [ "v_lagrangian_min = vmap(lagrangian_min)" ] },
 { "cell_type": "code", "execution_count": 102, "id": "8e83319e", "metadata": {}, "outputs": [], "source": [ "for epoch in range(200):\n", "    domain = v_lagrangian_min(domain)" ] },
 { "cell_type": "code", "execution_count": 103, "id": "d5f28369", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The maximum is 34.83720016479492, the argmax is (3.2659873962402344,3.2659854888916016,3.2659873962402344), the lagrangian is 0.8164968490600586\n" ] } ], "source": [ "# nanargmax picks the largest f among the points found, i.e. the constrained\n", "# maximum, so the names and the message say maximum rather than minimum.\n", "maximums = f_min(domain)\n", "\n", "arglist = nanargmax(maximums)\n", "argmax = domain[arglist]\n", "maximum = maximums[arglist]\n", "\n", "print(\"The maximum is {}, the argmax is ({},{},{}), the lagrangian is {}\".format(maximum,\n", "      argmax[0], argmax[1], argmax[2], argmax[3]))" ] },
 { "cell_type": "markdown", "id": "0bf20ea5", "metadata": {}, "source": [ "# Multivariable Multiconstrained Optimization\n", "\n", "The general form of the Lagrangian function can be written as:\n", "\n", "$$\n", " \\mathcal{L}(X) = f(X) - \\sum_{i=1}^{n} \\lambda_i g_i(X)\n", "$$" ] },
 { "cell_type": "code", "execution_count": null, "id": "5c1e7a3d", "metadata": {}, "outputs": [], "source": [ "def f(x): return 13*x[0]**2 + 10*x[0]*x[1] + 7*x[1]**2 + x[0] + x[1]\n", "def g(x): return 2*x[0] - 5*x[1] - 2\n", "def h(x): return x[0] + x[1] - 1" ] },
 { "cell_type": "code", "execution_count": 9, "id": "1fe85659", "metadata": {}, "outputs": [], "source": [ "f_min = vmap(f)\n", "\n", "def lagrange(l): return f(l[0:2]) - l[2]*g(l[:2]) - l[3]*h(l[:2])\n", "\n", "L = jacfwd(lagrange)\n", "gradL = jacfwd(L)\n", "\n", "def solve_lagrange(l): return l - 0.1 * linalg.inv(gradL(l)) @ L(l)" ] },
 { "cell_type": "code", "execution_count": 11, "id": "4497a49e", "metadata": {}, "outputs": [], "source": [ "domain = random.uniform(key, shape=(300,4), dtype=\"float32\", minval=-4, maxval=1)\n", "\n", "v_solve = vmap(solve_lagrange)\n", "for epoch in range(300):\n", "    domain = v_solve(domain)" ] },
 { "cell_type": "code", "execution_count": 12, "id": "f0e0fa5d", "metadata": {}, "outputs": [], "source": [ "maximums = f_min(domain)" ] },
 { "cell_type": "code", "execution_count": 18, "id": "ce6ee395", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The minimum is 13.999994277954102, the argmin is (0.9999997615814209,-1.063081178642733e-08), the lagrangian is 2.2857134342193604 and 22.428564071655273\n" ] } ], "source": [ "# g = 0 and h = 0 intersect in the single point (1, 0), so every converged start\n", "# yields the same value and nanargmax merely selects one of them.\n", "arglist = nanargmax(maximums)\n", "argmin = domain[arglist]\n", "minimum = maximums[arglist]\n", "\n", "print(\"The minimum is {}, the argmin is ({},{}), the lagrangian is {} and {}\".format(minimum,\n", " argmin[0], argmin[1], argmin[2], argmin[3]))"