summaryrefslogtreecommitdiff
path: root/jax_optimization_tutorial/basic-optimization.ipynb
diff options
context:
space:
mode:
Diffstat (limited to 'jax_optimization_tutorial/basic-optimization.ipynb')
-rw-r--r--jax_optimization_tutorial/basic-optimization.ipynb31
1 files changed, 20 insertions, 11 deletions
diff --git a/jax_optimization_tutorial/basic-optimization.ipynb b/jax_optimization_tutorial/basic-optimization.ipynb
index eb6a339..ddac210 100644
--- a/jax_optimization_tutorial/basic-optimization.ipynb
+++ b/jax_optimization_tutorial/basic-optimization.ipynb
@@ -14,7 +14,24 @@
"execution_count": 1,
"id": "02e5a88f",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "ename": "RuntimeError",
+ "evalue": "jaxlib is version 0.1.71, but this version of jax requires version >= 0.3.0.",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
+ "Input \u001b[0;32mIn [1]\u001b[0m, in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnp\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m grad, jit, vmap\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m random\n",
+ "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/jax-optimization-tutorial--qTljXSK-py3.9/lib/python3.9/site-packages/jax/__init__.py:37\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m _cloud_tpu_init\n\u001b[1;32m 32\u001b[0m \u001b[38;5;66;03m# flake8: noqa: F401\u001b[39;00m\n\u001b[1;32m 33\u001b[0m \n\u001b[1;32m 34\u001b[0m \u001b[38;5;66;03m# Confusingly there are two things named \"config\": the module and the class.\u001b[39;00m\n\u001b[1;32m 35\u001b[0m \u001b[38;5;66;03m# We want the exported object to be the class, so we first import the module\u001b[39;00m\n\u001b[1;32m 36\u001b[0m \u001b[38;5;66;03m# to make sure a later import doesn't overwrite the class.\u001b[39;00m\n\u001b[0;32m---> 37\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m config \u001b[38;5;28;01mas\u001b[39;00m _config_module\n\u001b[1;32m 38\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m _config_module\n\u001b[1;32m 40\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_src\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mconfig\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (\n\u001b[1;32m 41\u001b[0m config \u001b[38;5;28;01mas\u001b[39;00m config,\n\u001b[1;32m 42\u001b[0m enable_checks \u001b[38;5;28;01mas\u001b[39;00m enable_checks,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 56\u001b[0m transfer_guard_device_to_host \u001b[38;5;28;01mas\u001b[39;00m transfer_guard_device_to_host,\n\u001b[1;32m 57\u001b[0m )\n",
+ "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/jax-optimization-tutorial--qTljXSK-py3.9/lib/python3.9/site-packages/jax/config.py:18\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Copyright 2018 Google LLC\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;66;03m#\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;66;03m# Licensed under the Apache License, Version 2.0 (the \"License\");\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 16\u001b[0m \n\u001b[1;32m 17\u001b[0m \u001b[38;5;66;03m# flake8: noqa: F401\u001b[39;00m\n\u001b[0;32m---> 18\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_src\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mconfig\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m config\n",
+ "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/jax-optimization-tutorial--qTljXSK-py3.9/lib/python3.9/site-packages/jax/_src/config.py:27\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtyping\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Any, List, Callable, NamedTuple, Iterator, Optional\n\u001b[1;32m 25\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mwarnings\u001b[39;00m\n\u001b[0;32m---> 27\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_src\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m lib\n\u001b[1;32m 28\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_src\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mlib\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m jax_jit\n\u001b[1;32m 29\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m lib\u001b[38;5;241m.\u001b[39mxla_extension_version \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m58\u001b[39m:\n",
+ "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/jax-optimization-tutorial--qTljXSK-py3.9/lib/python3.9/site-packages/jax/_src/lib/__init__.py:87\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 84\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _jaxlib_version\n\u001b[1;32m 86\u001b[0m version_str \u001b[38;5;241m=\u001b[39m jaxlib\u001b[38;5;241m.\u001b[39mversion\u001b[38;5;241m.\u001b[39m__version__\n\u001b[0;32m---> 87\u001b[0m version \u001b[38;5;241m=\u001b[39m \u001b[43mcheck_jaxlib_version\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 88\u001b[0m \u001b[43m \u001b[49m\u001b[43mjax_version\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mjax\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mversion\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__version__\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 89\u001b[0m \u001b[43m \u001b[49m\u001b[43mjaxlib_version\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mjaxlib\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mversion\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__version__\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 90\u001b[0m \u001b[43m \u001b[49m\u001b[43mminimum_jaxlib_version\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mjax\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mversion\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_minimum_jaxlib_version\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 94\u001b[0m \u001b[38;5;66;03m# Before importing any C compiled modules from jaxlib, first import the CPU\u001b[39;00m\n\u001b[1;32m 95\u001b[0m \u001b[38;5;66;03m# feature guard module to verify that jaxlib was compiled in a way that only\u001b[39;00m\n\u001b[1;32m 96\u001b[0m \u001b[38;5;66;03m# uses instructions that are present on this machine.\u001b[39;00m\n\u001b[1;32m 97\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mjaxlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcpu_feature_guard\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mcpu_feature_guard\u001b[39;00m\n",
+ "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/jax-optimization-tutorial--qTljXSK-py3.9/lib/python3.9/site-packages/jax/_src/lib/__init__.py:76\u001b[0m, in \u001b[0;36mcheck_jaxlib_version\u001b[0;34m(jax_version, jaxlib_version, minimum_jaxlib_version)\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _jaxlib_version \u001b[38;5;241m<\u001b[39m _minimum_jaxlib_version:\n\u001b[1;32m 74\u001b[0m msg \u001b[38;5;241m=\u001b[39m (\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mjaxlib is version \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mjaxlib_version\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, but this version \u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 75\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mof jax requires version >= \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mminimum_jaxlib_version\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m---> 76\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(msg)\n\u001b[1;32m 78\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _jaxlib_version \u001b[38;5;241m>\u001b[39m _jax_version:\n\u001b[1;32m 79\u001b[0m msg \u001b[38;5;241m=\u001b[39m (\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mjaxlib version \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mjaxlib_version\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m is newer than and \u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 80\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mincompatible with jax version \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mjax_version\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. Please \u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 81\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mupdate your jax and/or jaxlib packages.\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
+ "\u001b[0;31mRuntimeError\u001b[0m: jaxlib is version 0.1.71, but this version of jax requires version >= 0.3.0."
+ ]
+ }
+ ],
"source": [
"import jax.numpy as np\n",
"from jax import grad, jit, vmap\n",
@@ -26,18 +43,10 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": null,
"id": "afe9f53c",
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"key = random.PRNGKey(42)"
]