# Solving Optimization Problems with JAX
Source: https://medium.com/swlh/solving-optimization-problems-with-jax-98376508bd4f

In [1]:
import jax.numpy as np
from jax import grad, jit, vmap
from jax import random
from jax import jacfwd, jacrev
from jax.numpy import linalg
from numpy import nanargmin, nanargmax

RuntimeError: jaxlib is version 0.1.71, but this version of jax requires version >= 0.3.0.

In [None]:
key = random.PRNGKey(42)

In [3]:
def f(x): return 3 * x[0] ** 2

In [4]:
gradf = grad(f)
gradf(np.array([2.0]))

DeviceArray([12.], dtype=float32)

In [5]:
def circle(x): return x[0] ** 2 + x[1] ** 2

In [6]:
J = jacfwd(circle)

In [7]:
J(np.array([1.0, 2.0]))

DeviceArray([2., 4.], dtype=float32)

In [14]:
def hessian(f): return jacfwd(jacrev(f))

In [15]:
H = hessian(circle)

In [16]:
H(np.array([1.0, 2.0]))

DeviceArray([[2., 0.],
 [0., 2.]], dtype=float32)

In [17]:
# Paper problem

In [18]:
def y(x): return ((x * np.sqrt(12 * x - 36)) / (2 * (x - 3)))
def L(x): return np.sqrt(x ** 2 + y(x) ** 2)

In [19]:
gradL = grad(L)

In [20]:
def minGD(x): return x - 0.01 * gradL(x)

In [21]:
domain = np.linspace(3.0, 5.0, num=50)

In [22]:
vfunGD = vmap(minGD)

In [25]:
for epoch in range(50):
 domain = vfunGD(domain)

In [28]:
minfun = vmap(L)
minimums = minfun(domain)

In [29]:
minimums

DeviceArray([ nan, 8.337169 , 7.8126216, 7.8554583, 7.8657427,
 7.8686085, 7.869442 , 7.869495 , 7.8691173, 7.8684244,
 7.867449 , 7.8662043, 7.8646903, 7.8629084, 7.860858 ,
 7.8585405, 7.8559637, 7.853135 , 7.8500714, 7.8467846,
 7.843305 , 7.8396564, 7.8358707, 7.8319883, 7.828048 ,
 7.8240952, 7.8201776, 7.8163476, 7.812655 , 7.809155 ,
 7.805899 , 7.8029428, 7.800338 , 7.7981334, 7.7963796,
 7.795119 , 7.7943954, 7.7942476, 7.7947106, 7.795813 ,
 7.797586 , 7.8000493, 7.8032246, 7.807126 , 7.811767 ,
 7.8171563, 7.8233 , 7.830198 , 7.837856 , 7.846267 ], dtype=float32)

In [30]:
arglist = nanargmin(minimums)

In [31]:
arglist

37

In [32]:
argmin = domain[arglist]

In [33]:
minimum = minimums[arglist]

In [34]:
print(f"the minimum is {minimum} the arg min is {argmin}")

the minimum is 7.794247627258301 the arg min is 4.505731105804443


# Newton's method

Solve for where $f(x) = 0$

$$
 x_{n+1} = x_n - \frac{f(x_n)}{f^\prime(x_n)}
$$

But we search for $f^\prime(x) = 0$

$$
 x_{n+1} = x_n - \frac{f^{\prime}(x_n)}{f^{\prime\prime}(x_n)}
 $$

In [35]:
gradL = grad(L)

In [36]:
gradL2 = grad(gradL)

In [39]:
def min_newton(x): return x - gradL(x) / gradL2(x)

In [38]:
domain = np.linspace(3.0, 5.0, num=50)

In [40]:
f_newton = vmap(min_newton)

In [41]:
for epoch in range(50):
 domain = f_newton(domain)

In [43]:
mins = minfun(domain)

In [44]:
arglist = nanargmin(minimums)
argmin = domain[arglist]
minimum = minimums[arglist]

In [45]:
print(f"the minimum is {minimum} the arg min is {argmin}")

the minimum is 7.794247627258301 the arg min is 4.499999523162842


# Multivariable Optimization

$$
 X_{n+1} = X_n - 0.01 \nabla f(X_n)
$$

In [48]:
def paraboloid(x): return (x[0] * x[1] - 2) ** 2 + (x[1] - 3) ** 2

In [49]:
f_min = vmap(paraboloid)

In [50]:
J = jacfwd(paraboloid)

In [51]:
def jacobian_min(x): return x - 0.01 * J(x)

In [69]:
domain = random.uniform(key, shape=(50, 2), dtype="float32", minval=-5.0, maxval=5.0)

In [70]:
v_jacobian_min = vmap(jacobian_min)

In [71]:
for epoch in range(150):
 domain = v_jacobian_min(domain)

In [72]:
mins = f_min(domain)

In [73]:
arglist = nanargmin(mins)
argmin = domain[arglist]
minimum = minimums[arglist]

In [74]:
print(f"the minimum is {minimum} the arg min is {argmin}")

the minimum is 7.794247627258301 the arg min is [-0.54916656 0.60931224]


# The Hessian

Similar to the Newton method:
$$
 X_{n+1} = X_n - H^{-1}(X_n) \nabla f(X_n)
$$

In [60]:
def hessian(f): return jacfwd(jacrev(f))

In [61]:
H = hessian(paraboloid)

In [62]:
def hessian_min(x): return x - 0.01 * linalg.inv(H(x)) @ J(x)

In [75]:
domain = random.uniform(key, shape=(50, 2), dtype="float32", minval=-5.0, maxval=5.0)

In [76]:
v_hessian_min = vmap(hessian_min)

In [77]:
for epoch in range(150):
 domain = v_hessian_min(domain)

In [78]:
mins = f_min(domain)

In [79]:
arglist = nanargmin(mins)
argmin = domain[arglist]
minimum = minimums[arglist]
print(f"the minimum is {minimum} the arg min is {argmin}")

the minimum is 7.794247627258301 the arg min is [-0.74005073 -0.00774691]


# Mutlivariable Constrained Optimization

Lagrangian multipliers:

$$
\begin{align}
\nabla f = \lambda \nabla g\\
\nabla f - \lambda \nabla g = 0\\
\nabla (f - \lambda g) = 0\\
\nabla (\mathcal{L}) = 0\\
\end{align}
$$

Our iterative approach will be:

$$
X_{n+1} = X_n - \nabla^2\mathcal{L}^{-1}(X_n)\nabla\mathcal{L}(X_n)
$$

In [80]:
def f(x): return 4 * (x[0] ** 2) * x[1]
def g(x): return x[0]**2 + x[1]**2 - 3

In [81]:
f_min = vmap(f)

In [82]:
def lagrange(l): return f(l[0:2]) - l[2]*g(l[0:2])

In [83]:
L = jacfwd(lagrange)

In [84]:
gradL = jacfwd(L)

In [85]:
def lagrange_min(l): return l - linalg.inv(gradL(l)) @ L(l)

In [86]:
domain = random.uniform(key, shape=(50,3), dtype="float32", minval=-5.0, maxval=5.0)

In [87]:
v_lagrange_min = vmap(lagrange_min)

In [88]:
for epoch in range(150):
 domain = v_lagrange_min(domain)

In [89]:
mins = f_min(domain)

In [91]:
arglist = nanargmin(mins)
argmin = domain[arglist]
minimum = minimums[arglist]
print(f"the minimum is {minimum} the arg min is {argmin}")

the minimum is 7.8554582595825195 the arg min is [-1.4142135 -1. -4. ]


# Three Variable Multivariable Constrained Optimization
$$
\begin{align}
f(x, y, z)& = xyz\\
g(x, y, z)& = 2xy + 2yz + 2xz - 64\\
\end{align}
$$

Our iterative approach will be:

$$
X_{n+1} = X_n - 0.1 \nabla^2\mathcal{L}^{-1}(X_n)\nabla\mathcal{L}(X_n)
$$

In [98]:
def f(x): return x[0]*x[1]*x[2]
def g(x): return 2*x[0]*x[1] + 2*x[1]*x[2] + 2*x[0]*x[2] - 64

f_min = vmap(f)

def lagrange(l): return f(l[0:3]) - l[3]*g(l[0:3])

L = jacfwd(lagrange)
gradL = jacfwd(L)

In [99]:
def lagrangian_min(l): return l - 0.1 * linalg.inv(gradL(l)) @ L(l)

In [100]:
domain = random.uniform(key, shape=(50, 4), dtype="float32", minval = 0, maxval=10)

In [101]:
v_lagrangian_min = vmap(lagrangian_min)

In [102]:
for epoch in range(200):
 domain = v_lagrangian_min(domain)

In [103]:
maximums = f_min(domain)

arglist = nanargmax(maximums)
argmin = domain[arglist]
minimum = maximums[arglist]

print("The minimum is {}, the argmin is ({},{},{}), the lagrangian is {}".format(minimum,
 argmin[0], argmin[1], argmin[2], argmin[3]))

The minimum is 34.83720016479492, the argmin is (3.2659873962402344,3.2659854888916016,3.2659873962402344), the lagrangian is 0.8164968490600586


# Multivariable MultiConstrained Optimization