From dcdb2f4e11962b5f82c184e4deb1b9c0d51fdf95 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 22 Mar 2022 20:51:32 +0100 Subject: build: update to jax with gpu support --- poetry.lock | 21 ++++++--------------- pyproject.toml | 4 ++-- 2 files changed, 8 insertions(+), 17 deletions(-) diff --git a/poetry.lock b/poetry.lock index 8738dce..223f0f0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -298,7 +298,7 @@ tpu = ["jaxlib (==0.3.2)", "libtpu-nightly (==0.1.dev20220315)", "requests"] [[package]] name = "jaxlib" -version = "0.3.2" +version = "0.3.2+cuda11.cudnn82" description = "XLA library for JAX" category = "main" optional = false @@ -310,6 +310,9 @@ flatbuffers = ">=1.12,<3.0" numpy = ">=1.19" scipy = "*" +[package.source] +type = "url" +url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.2+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl" [[package]] name = "jedi" version = "0.18.1" @@ -996,7 +999,7 @@ notebook = ">=4.4.1" [metadata] lock-version = "1.1" python-versions = "^3.9" -content-hash = "c361c72b1e4faba684898b8e62490234f51e7131e44a6e09bf67511e7aa7698d" +content-hash = "64a828dee5110970cb741ce86a5ab71a9c94f625d9585afe3ee79286a7a6d940" [metadata.files] absl-py = [ @@ -1176,19 +1179,7 @@ ipywidgets = [ jax = [ {file = "jax-0.3.4.tar.gz", hash = "sha256:f26854ea8f5449493fadc2f70936c57b7e07b5eb91f81f710748e93e97ec785a"}, ] -jaxlib = [ - {file = "jaxlib-0.3.2-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:338daac14405104b1d193ff8b7ab1e91eb8cdfdf6329cd189f6d041e29a57355"}, - {file = "jaxlib-0.3.2-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:6bc65f76c96dbeb8f2d5dc6a0fb7c0ab39642bec3e6fc6de00eb1f2a41401b91"}, - {file = "jaxlib-0.3.2-cp310-none-manylinux2010_x86_64.whl", hash = "sha256:b8314f9ee62642d899d286456e83ad4de819e539f6100383a83d532b0ef0b97e"}, - {file = "jaxlib-0.3.2-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:0a425b09836ed682dcdfc88e79595cbcd2987b55ac6219c382826f9ce4a7eb47"}, - {file = "jaxlib-0.3.2-cp37-none-manylinux2010_x86_64.whl", hash = "sha256:8c1a670d868ffc0df217ff0faffca8c1855c5c844773f771ef70a68725ccff89"}, - {file = "jaxlib-0.3.2-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:d1149dd4bb1f66d0a3e76337a59dcbd864ca98cabdc559e548b95fc874368200"}, - {file = "jaxlib-0.3.2-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:21506c98bdd50e3feb21fac3d245a632ff71eed76ceac7d00b5e9ad9e9ce36f4"}, - {file = "jaxlib-0.3.2-cp38-none-manylinux2010_x86_64.whl", hash = "sha256:3a42055bea0737d1559c79c0307f6510722a9b26b3c4b93f8efc043ec9d02502"}, - {file = "jaxlib-0.3.2-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:455ed3749a50b105cbd4010f3c2b3e11378ce6ca828d90a3e69f6f7fc2af9bc9"}, - {file = "jaxlib-0.3.2-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:a9db5f072b463f5ebc1d792ea41a8acc0821876f90c8fdf7c3d4e38e09b7ca64"}, - {file = "jaxlib-0.3.2-cp39-none-manylinux2010_x86_64.whl", hash = "sha256:159ef0ec9afa7d0d6b4d1ef18aa49825db15173a86fe7bb879c2ecbbace4d9a9"}, -] +jaxlib = [] jedi = [ {file = "jedi-0.18.1-py2.py3-none-any.whl", hash = "sha256:637c9635fcf47945ceb91cd7f320234a7be540ded6f3e99a50cb6febdfd1ba8d"}, {file = "jedi-0.18.1.tar.gz", hash = "sha256:74137626a64a99c8eb6ae5832d99b3bdd7d29a3850fe2aa80a4126b2a7d949ab"}, diff --git a/pyproject.toml b/pyproject.toml index d3ea39d..9a2eed3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,9 +6,9 @@ authors = ["Gustaf Rydholm "] [tool.poetry.dependencies] python = "^3.9" -jax = "^0.3.4" jupyter = "^1.0.0" -jaxlib = "^0.3.2" +jaxlib = {url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.2+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl"} +jax = "^0.3.4" [tool.poetry.dev-dependencies] pytest = "^5.2" -- cgit v1.2.3-70-g09d2