summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-03-21 23:55:53 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-03-21 23:55:53 +0100
commit53b256b0a49ff11b42c04f0a11272899571b389d (patch)
tree82f84c1c145fd7a36f88d5be9e1da5edb73493b7
feat(constrained optimization): basic tutorial
-rw-r--r--jax_optimization_tutorial/__init__.py1
-rw-r--r--jax_optimization_tutorial/basic-optimization.ipynb981
2 files changed, 982 insertions, 0 deletions
diff --git a/jax_optimization_tutorial/__init__.py b/jax_optimization_tutorial/__init__.py
new file mode 100644
index 0000000..b794fd4
--- /dev/null
+++ b/jax_optimization_tutorial/__init__.py
@@ -0,0 +1 @@
+__version__ = '0.1.0'
diff --git a/jax_optimization_tutorial/basic-optimization.ipynb b/jax_optimization_tutorial/basic-optimization.ipynb
new file mode 100644
index 0000000..eb6a339
--- /dev/null
+++ b/jax_optimization_tutorial/basic-optimization.ipynb
@@ -0,0 +1,981 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "f48669df",
+ "metadata": {},
+ "source": [
+ "# Solving Optimization Problems with JAX\n",
+ "Source: https://medium.com/swlh/solving-optimization-problems-with-jax-98376508bd4f"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "02e5a88f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import jax.numpy as np\n",
+ "from jax import grad, jit, vmap\n",
+ "from jax import random\n",
+ "from jax import jacfwd, jacrev\n",
+ "from jax.numpy import linalg\n",
+ "from numpy import nanargmin, nanargmax"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "afe9f53c",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
+ ]
+ }
+ ],
+ "source": [
+ "key = random.PRNGKey(42)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "b8bc9bbe",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def f(x): return 3 * x[0] ** 2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "aebb2724",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "DeviceArray([12.], dtype=float32)"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "gradf = grad(f)\n",
+ "gradf(np.array([2.0]))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "9d891860",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def circle(x): return x[0] ** 2 + x[1] ** 2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "6a70bd20",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "J = jacfwd(circle)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "dff71dc6",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "DeviceArray([2., 4.], dtype=float32)"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "J(np.array([1.0, 2.0]))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "b90095a4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def hessian(f): return jacfwd(jacrev(f))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "eb4124f0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "H = hessian(circle)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "a6d57fb1",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "DeviceArray([[2., 0.],\n",
+ " [0., 2.]], dtype=float32)"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "H(np.array([1.0, 2.0]))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "b50fe2a5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Paper problem"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "89a1d95d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def y(x): return ((x * np.sqrt(12 * x - 36)) / (2 * (x - 3)))\n",
+ "def L(x): return np.sqrt(x ** 2 + y(x) ** 2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "402c0f64",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "gradL = grad(L)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "5bf3e5c2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def minGD(x): return x - 0.01 * gradL(x)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "e340a396",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "domain = np.linspace(3.0, 5.0, num=50)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "1d054e99",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "vfunGD = vmap(minGD)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "33a48ebf",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for epoch in range(50):\n",
+ " domain = vfunGD(domain)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "id": "d698e29a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "minfun = vmap(L)\n",
+ "minimums = minfun(domain)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "id": "eed252b6",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "DeviceArray([ nan, 8.337169 , 7.8126216, 7.8554583, 7.8657427,\n",
+ " 7.8686085, 7.869442 , 7.869495 , 7.8691173, 7.8684244,\n",
+ " 7.867449 , 7.8662043, 7.8646903, 7.8629084, 7.860858 ,\n",
+ " 7.8585405, 7.8559637, 7.853135 , 7.8500714, 7.8467846,\n",
+ " 7.843305 , 7.8396564, 7.8358707, 7.8319883, 7.828048 ,\n",
+ " 7.8240952, 7.8201776, 7.8163476, 7.812655 , 7.809155 ,\n",
+ " 7.805899 , 7.8029428, 7.800338 , 7.7981334, 7.7963796,\n",
+ " 7.795119 , 7.7943954, 7.7942476, 7.7947106, 7.795813 ,\n",
+ " 7.797586 , 7.8000493, 7.8032246, 7.807126 , 7.811767 ,\n",
+ " 7.8171563, 7.8233 , 7.830198 , 7.837856 , 7.846267 ], dtype=float32)"
+ ]
+ },
+ "execution_count": 29,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "minimums"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "id": "92a900e8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "arglist = nanargmin(minimums)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "id": "a520c53f",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "37"
+ ]
+ },
+ "execution_count": 31,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "arglist"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "id": "e5da0824",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "argmin = domain[arglist]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "id": "d2d53cec",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "minimum = minimums[arglist]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "id": "930267ad",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "the minimum is 7.794247627258301 the arg min is 4.505731105804443\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(f\"the minimum is {minimum} the arg min is {argmin}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c7910312",
+ "metadata": {},
+ "source": [
+ "# Newton's method\n",
+ "\n",
+ "Solve for where $f(x) = 0$\n",
+ "\n",
+ "$$\n",
+ " x_{n+1} = x_n - \\frac{f(x_n)}{f^\\prime(x_n)}\n",
+ "$$\n",
+ "\n",
+ "But we search for $f^\\prime(x) = 0$\n",
+ "\n",
+ "$$\n",
+ " x_{n+1} = x_n - \\frac{f^{\\prime}(x_n)}{f^{\\prime\\prime}(x_n)}\n",
+ " $$"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "id": "426a3238",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "gradL = grad(L)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 36,
+ "id": "21cd42fa",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "gradL2 = grad(gradL)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 39,
+ "id": "4c5b3265",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def min_newton(x): return x - gradL(x) / gradL2(x)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 38,
+ "id": "4dd342ac",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "domain = np.linspace(3.0, 5.0, num=50)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "id": "235d15b4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "f_newton = vmap(min_newton)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 41,
+ "id": "9a88cd8b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for epoch in range(50):\n",
+ " domain = f_newton(domain)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 43,
+ "id": "529dbb0e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mins = minfun(domain)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 44,
+ "id": "65aa7311",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "arglist = nanargmin(minimums)\n",
+ "argmin = domain[arglist]\n",
+ "minimum = minimums[arglist]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 45,
+ "id": "c0327e0e",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "the minimum is 7.794247627258301 the arg min is 4.499999523162842\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(f\"the minimum is {minimum} the arg min is {argmin}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1180ee4c",
+ "metadata": {},
+ "source": [
+ "# Multivariable Optimization\n",
+ "\n",
+ "$$\n",
+ " X_{n+1} = X_n - 0.01 \\nabla f(X_n)\n",
+ "$$"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 48,
+ "id": "2a968948",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def paraboloid(x): return (x[0] * x[1] - 2) ** 2 + (x[1] - 3) ** 2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 49,
+ "id": "f893af70",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "f_min = vmap(paraboloid)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 50,
+ "id": "0b035649",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "J = jacfwd(paraboloid)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 51,
+ "id": "edd62878",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def jacobian_min(x): return x - 0.01 * J(x)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 69,
+ "id": "81c302fd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "domain = random.uniform(key, shape=(50, 2), dtype=\"float32\", minval=-5.0, maxval=5.0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 70,
+ "id": "70d0555a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "v_jacobian_min = vmap(jacobian_min)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 71,
+ "id": "d53b2818",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for epoch in range(150):\n",
+ " domain = v_jacobian_min(domain)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 72,
+ "id": "93c3cfdf",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mins = f_min(domain)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 73,
+ "id": "c2f9afa2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "arglist = nanargmin(mins)\n",
+ "argmin = domain[arglist]\n",
+ "minimum = minimums[arglist]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 74,
+ "id": "91ebaa54",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "the minimum is 7.794247627258301 the arg min is [-0.54916656 0.60931224]\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(f\"the minimum is {minimum} the arg min is {argmin}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e5e16d79",
+ "metadata": {},
+ "source": [
+ "# The Hessian\n",
+ "\n",
+ "Similar to the Newton method:\n",
+ "$$\n",
+ " X_{n+1} = X_n - H^{-1}(X_n) \\nabla f(X_n)\n",
+ "$$"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 60,
+ "id": "0c11e322",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def hessian(f): return jacfwd(jacrev(f))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 61,
+ "id": "7a5d3955",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "H = hessian(paraboloid)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 62,
+ "id": "8b32404d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def hessian_min(x): return x - 0.01 * linalg.inv(H(x)) @ J(x)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 75,
+ "id": "050f46e5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "domain = random.uniform(key, shape=(50, 2), dtype=\"float32\", minval=-5.0, maxval=5.0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 76,
+ "id": "cba5bd2b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "v_hessian_min = vmap(hessian_min)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 77,
+ "id": "27b219df",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for epoch in range(150):\n",
+ " domain = v_hessian_min(domain)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 78,
+ "id": "dfdf027b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mins = f_min(domain)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 79,
+ "id": "db074a43",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "the minimum is 7.794247627258301 the arg min is [-0.74005073 -0.00774691]\n"
+ ]
+ }
+ ],
+ "source": [
+ "arglist = nanargmin(mins)\n",
+ "argmin = domain[arglist]\n",
+ "minimum = minimums[arglist]\n",
+ "print(f\"the minimum is {minimum} the arg min is {argmin}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e2cc6b7d",
+ "metadata": {},
+ "source": [
+ "# Mutlivariable Constrained Optimization\n",
+ "\n",
+ "Lagrangian multipliers:\n",
+ "\n",
+ "$$\n",
+ "\\begin{align}\n",
+ "\\nabla f = \\lambda \\nabla g\\\\\n",
+ "\\nabla f - \\lambda \\nabla g = 0\\\\\n",
+ "\\nabla (f - \\lambda g) = 0\\\\\n",
+ "\\nabla (\\mathcal{L}) = 0\\\\\n",
+ "\\end{align}\n",
+ "$$\n",
+ "\n",
+ "Our iterative approach will be:\n",
+ "\n",
+ "$$\n",
+ "X_{n+1} = X_n - \\nabla^2\\mathcal{L}^{-1}(X_n)\\nabla\\mathcal{L}(X_n)\n",
+ "$$"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 80,
+ "id": "a8ff98bc",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def f(x): return 4 * (x[0] ** 2) * x[1]\n",
+ "def g(x): return x[0]**2 + x[1]**2 - 3"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 81,
+ "id": "ba2ce45c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "f_min = vmap(f)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 82,
+ "id": "c0d03ac0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def lagrange(l): return f(l[0:2]) - l[2]*g(l[0:2])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 83,
+ "id": "fc1d5af3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "L = jacfwd(lagrange)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 84,
+ "id": "60f8134b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "gradL = jacfwd(L)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 85,
+ "id": "ef6ddc53",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def lagrange_min(l): return l - linalg.inv(gradL(l)) @ L(l)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 86,
+ "id": "c67b6785",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "domain = random.uniform(key, shape=(50,3), dtype=\"float32\", minval=-5.0, maxval=5.0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 87,
+ "id": "d61873ab",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "v_lagrange_min = vmap(lagrange_min)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 88,
+ "id": "73539993",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for epoch in range(150):\n",
+ " domain = v_lagrange_min(domain)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 89,
+ "id": "7f4d91c7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mins = f_min(domain)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 91,
+ "id": "38f56026",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "the minimum is 7.8554582595825195 the arg min is [-1.4142135 -1. -4. ]\n"
+ ]
+ }
+ ],
+ "source": [
+ "arglist = nanargmin(mins)\n",
+ "argmin = domain[arglist]\n",
+ "minimum = minimums[arglist]\n",
+ "print(f\"the minimum is {minimum} the arg min is {argmin}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "380c75c3",
+ "metadata": {},
+ "source": [
+ "# Three Variable Multivariable Constrained Optimization\n",
+ "$$\n",
+ "\\begin{align}\n",
+ "f(x, y, z)& = xyz\\\\\n",
+ "g(x, y, z)& = 2xy + 2yz + 2xz - 64\\\\\n",
+ "\\end{align}\n",
+ "$$\n",
+ "\n",
+ "Our iterative approach will be:\n",
+ "\n",
+ "$$\n",
+ "X_{n+1} = X_n - 0.1 \\nabla^2\\mathcal{L}^{-1}(X_n)\\nabla\\mathcal{L}(X_n)\n",
+ "$$"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 98,
+ "id": "820312ab",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def f(x): return x[0]*x[1]*x[2]\n",
+ "def g(x): return 2*x[0]*x[1] + 2*x[1]*x[2] + 2*x[0]*x[2] - 64\n",
+ "\n",
+ "f_min = vmap(f)\n",
+ "\n",
+ "def lagrange(l): return f(l[0:3]) - l[3]*g(l[0:3])\n",
+ "\n",
+ "L = jacfwd(lagrange)\n",
+ "gradL = jacfwd(L)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 99,
+ "id": "61f9db1e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def lagrangian_min(l): return l - 0.1 * linalg.inv(gradL(l)) @ L(l)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 100,
+ "id": "9e8d959f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "domain = random.uniform(key, shape=(50, 4), dtype=\"float32\", minval = 0, maxval=10)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 101,
+ "id": "c3b8c4fa",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "v_lagrangian_min = vmap(lagrangian_min)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 102,
+ "id": "8e83319e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for epoch in range(200):\n",
+ " domain = v_lagrangian_min(domain)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 103,
+ "id": "d5f28369",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The minimum is 34.83720016479492, the argmin is (3.2659873962402344,3.2659854888916016,3.2659873962402344), the lagrangian is 0.8164968490600586\n"
+ ]
+ }
+ ],
+ "source": [
+ "maximums = f_min(domain)\n",
+ "\n",
+ "arglist = nanargmax(maximums)\n",
+ "argmin = domain[arglist]\n",
+ "minimum = maximums[arglist]\n",
+ "\n",
+ "print(\"The minimum is {}, the argmin is ({},{},{}), the lagrangian is {}\".format(minimum,\n",
+ " argmin[0], argmin[1], argmin[2], argmin[3]))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0bf20ea5",
+ "metadata": {},
+ "source": [
+ "# Multivariable MultiConstrained Optimization"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e834f8ba",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}