From e8663d381446c8c1024ac92af2760ccfe917d76b Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 22 Mar 2022 20:51:14 +0100 Subject: fix: check if jax is using gpu --- jax_optimization_tutorial/basic-optimization.ipynb | 31 ++++++++++++++-------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/jax_optimization_tutorial/basic-optimization.ipynb b/jax_optimization_tutorial/basic-optimization.ipynb index eb6a339..ddac210 100644 --- a/jax_optimization_tutorial/basic-optimization.ipynb +++ b/jax_optimization_tutorial/basic-optimization.ipynb @@ -14,7 +14,24 @@ "execution_count": 1, "id": "02e5a88f", "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "RuntimeError", + "evalue": "jaxlib is version 0.1.71, but this version of jax requires version >= 0.3.0.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "Input \u001b[0;32mIn [1]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnp\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m grad, jit, vmap\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m random\n", + "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/jax-optimization-tutorial--qTljXSK-py3.9/lib/python3.9/site-packages/jax/__init__.py:37\u001b[0m, in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m _cloud_tpu_init\n\u001b[1;32m 32\u001b[0m \u001b[38;5;66;03m# flake8: noqa: F401\u001b[39;00m\n\u001b[1;32m 33\u001b[0m \n\u001b[1;32m 34\u001b[0m \u001b[38;5;66;03m# Confusingly there are two things named \"config\": the module and the class.\u001b[39;00m\n\u001b[1;32m 35\u001b[0m \u001b[38;5;66;03m# We want the exported object to be the class, so we first import the module\u001b[39;00m\n\u001b[1;32m 36\u001b[0m \u001b[38;5;66;03m# to make sure a later import doesn't overwrite the class.\u001b[39;00m\n\u001b[0;32m---> 37\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m config \u001b[38;5;28;01mas\u001b[39;00m _config_module\n\u001b[1;32m 38\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m _config_module\n\u001b[1;32m 40\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_src\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mconfig\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (\n\u001b[1;32m 41\u001b[0m config \u001b[38;5;28;01mas\u001b[39;00m config,\n\u001b[1;32m 42\u001b[0m enable_checks \u001b[38;5;28;01mas\u001b[39;00m enable_checks,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 56\u001b[0m transfer_guard_device_to_host \u001b[38;5;28;01mas\u001b[39;00m transfer_guard_device_to_host,\n\u001b[1;32m 57\u001b[0m )\n", + "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/jax-optimization-tutorial--qTljXSK-py3.9/lib/python3.9/site-packages/jax/config.py:18\u001b[0m, in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Copyright 2018 Google LLC\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;66;03m#\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;66;03m# Licensed under the Apache License, Version 2.0 (the \"License\");\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 16\u001b[0m \n\u001b[1;32m 17\u001b[0m \u001b[38;5;66;03m# flake8: noqa: F401\u001b[39;00m\n\u001b[0;32m---> 18\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_src\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mconfig\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m config\n", + "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/jax-optimization-tutorial--qTljXSK-py3.9/lib/python3.9/site-packages/jax/_src/config.py:27\u001b[0m, in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtyping\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Any, List, Callable, NamedTuple, Iterator, Optional\n\u001b[1;32m 25\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mwarnings\u001b[39;00m\n\u001b[0;32m---> 27\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_src\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m lib\n\u001b[1;32m 28\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_src\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mlib\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m jax_jit\n\u001b[1;32m 29\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m lib\u001b[38;5;241m.\u001b[39mxla_extension_version \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m58\u001b[39m:\n", + "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/jax-optimization-tutorial--qTljXSK-py3.9/lib/python3.9/site-packages/jax/_src/lib/__init__.py:87\u001b[0m, in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 84\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _jaxlib_version\n\u001b[1;32m 86\u001b[0m version_str \u001b[38;5;241m=\u001b[39m jaxlib\u001b[38;5;241m.\u001b[39mversion\u001b[38;5;241m.\u001b[39m__version__\n\u001b[0;32m---> 87\u001b[0m version \u001b[38;5;241m=\u001b[39m \u001b[43mcheck_jaxlib_version\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 88\u001b[0m \u001b[43m \u001b[49m\u001b[43mjax_version\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mjax\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mversion\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__version__\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 89\u001b[0m \u001b[43m \u001b[49m\u001b[43mjaxlib_version\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mjaxlib\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mversion\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__version__\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 90\u001b[0m \u001b[43m \u001b[49m\u001b[43mminimum_jaxlib_version\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mjax\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mversion\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_minimum_jaxlib_version\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 94\u001b[0m \u001b[38;5;66;03m# Before importing any C compiled modules from jaxlib, first import the CPU\u001b[39;00m\n\u001b[1;32m 95\u001b[0m \u001b[38;5;66;03m# feature guard module to verify that jaxlib was compiled in a way that only\u001b[39;00m\n\u001b[1;32m 96\u001b[0m \u001b[38;5;66;03m# uses instructions that are present on this machine.\u001b[39;00m\n\u001b[1;32m 97\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mjaxlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcpu_feature_guard\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mcpu_feature_guard\u001b[39;00m\n", + "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/jax-optimization-tutorial--qTljXSK-py3.9/lib/python3.9/site-packages/jax/_src/lib/__init__.py:76\u001b[0m, in \u001b[0;36mcheck_jaxlib_version\u001b[0;34m(jax_version, jaxlib_version, minimum_jaxlib_version)\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _jaxlib_version \u001b[38;5;241m<\u001b[39m _minimum_jaxlib_version:\n\u001b[1;32m 74\u001b[0m msg \u001b[38;5;241m=\u001b[39m (\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mjaxlib is version \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mjaxlib_version\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, but this version \u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 75\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mof jax requires version >= \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mminimum_jaxlib_version\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m---> 76\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(msg)\n\u001b[1;32m 78\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _jaxlib_version \u001b[38;5;241m>\u001b[39m _jax_version:\n\u001b[1;32m 79\u001b[0m msg \u001b[38;5;241m=\u001b[39m (\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mjaxlib version \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mjaxlib_version\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m is newer than and \u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 80\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mincompatible with jax version \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mjax_version\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. Please \u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 81\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mupdate your jax and/or jaxlib packages.\u001b[39m\u001b[38;5;124m'\u001b[39m)\n", + "\u001b[0;31mRuntimeError\u001b[0m: jaxlib is version 0.1.71, but this version of jax requires version >= 0.3.0." + ] + } + ], "source": [ "import jax.numpy as np\n", "from jax import grad, jit, vmap\n", @@ -26,18 +43,10 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "afe9f53c", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" - ] - } - ], + "outputs": [], "source": [ "key = random.PRNGKey(42)" ] -- cgit v1.2.3-70-g09d2