From 1f459ba19422593de325983040e176f97cf4ffc0 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Thu, 20 Aug 2020 22:18:35 +0200 Subject: A lot of stuff working :D. ResNet implemented! --- .flake8 | 2 +- README.md | 2 +- poetry.lock | 384 ++++++++----- pyproject.toml | 5 +- src/notebooks/00-testing-stuff-out.ipynb | 639 ++++++++++++++++++++- src/notebooks/01-look-at-emnist.ipynb | 25 +- src/text_recognizer/datasets/emnist_dataset.py | 43 +- .../datasets/emnist_lines_dataset.py | 9 +- src/text_recognizer/models/base.py | 45 +- src/text_recognizer/models/character_model.py | 8 +- src/text_recognizer/networks/__init__.py | 3 +- src/text_recognizer/networks/lenet.py | 17 +- src/text_recognizer/networks/misc.py | 20 +- src/text_recognizer/networks/mlp.py | 18 +- src/text_recognizer/networks/residual_network.py | 314 ++++++++++ .../CharacterModel_EmnistDataset_LeNet_weights.pt | Bin 14485310 -> 14485362 bytes .../CharacterModel_EmnistDataset_MLP_weights.pt | Bin 1704174 -> 11625484 bytes ...rModel_EmnistDataset_ResidualNetwork_weights.pt | Bin 0 -> 28654593 bytes src/training/callbacks/__init__.py | 19 - src/training/callbacks/base.py | 240 -------- src/training/callbacks/early_stopping.py | 107 ---- src/training/callbacks/lr_schedulers.py | 97 ---- src/training/callbacks/wandb_callbacks.py | 93 --- src/training/experiments/sample_experiment.yml | 37 +- src/training/population_based_training/__init__.py | 1 - .../population_based_training.py | 1 - src/training/prepare_experiments.py | 6 +- src/training/run_experiment.py | 19 +- src/training/train.py | 249 -------- src/training/trainer/__init__.py | 2 + src/training/trainer/callbacks/__init__.py | 21 + src/training/trainer/callbacks/base.py | 248 ++++++++ src/training/trainer/callbacks/early_stopping.py | 108 ++++ src/training/trainer/callbacks/lr_schedulers.py | 97 ++++ src/training/trainer/callbacks/progress_bar.py | 61 ++ src/training/trainer/callbacks/wandb_callbacks.py | 93 +++ .../trainer/population_based_training/__init__.py | 1 + .../population_based_training.py | 1 + src/training/trainer/train.py | 216 +++++++ src/training/trainer/util.py | 19 + src/training/util.py | 19 - 41 files changed, 2185 insertions(+), 1104 deletions(-) create mode 100644 src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt delete mode 100644 src/training/callbacks/__init__.py delete mode 100644 src/training/callbacks/base.py delete mode 100644 src/training/callbacks/early_stopping.py delete mode 100644 src/training/callbacks/lr_schedulers.py delete mode 100644 src/training/callbacks/wandb_callbacks.py delete mode 100644 src/training/population_based_training/__init__.py delete mode 100644 src/training/population_based_training/population_based_training.py delete mode 100644 src/training/train.py create mode 100644 src/training/trainer/__init__.py create mode 100644 src/training/trainer/callbacks/__init__.py create mode 100644 src/training/trainer/callbacks/base.py create mode 100644 src/training/trainer/callbacks/early_stopping.py create mode 100644 src/training/trainer/callbacks/lr_schedulers.py create mode 100644 src/training/trainer/callbacks/progress_bar.py create mode 100644 src/training/trainer/callbacks/wandb_callbacks.py create mode 100644 src/training/trainer/population_based_training/__init__.py create mode 100644 src/training/trainer/population_based_training/population_based_training.py create mode 100644 src/training/trainer/train.py create mode 100644 src/training/trainer/util.py delete mode 100644 src/training/util.py 
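[Editorial note, not part of the patch] The diffstat above adds src/text_recognizer/networks/residual_network.py. The patch itself does not show that file, but the notebook cells further down print module reprs such as BasicBlock with two 3x3 Conv2dAuto/BatchNorm2d stages and a 1x1 Conv2d shortcut. The following is a minimal, hypothetical sketch with that same layout; the class name, constructor signature, and the omission of striding/downsampling are assumptions for illustration, not the actual file contents.

    # Hypothetical sketch, mirroring the structure printed in the notebook below:
    # out = activation(blocks(x) + shortcut(x)), with a 1x1 conv projection when
    # the channel count changes. Not taken from residual_network.py.
    import torch
    from torch import nn


    class BasicBlock(nn.Module):
        """Residual block with two 3x3 conv/batch-norm stages."""

        def __init__(self, in_channels: int, out_channels: int) -> None:
            super().__init__()
            self.blocks = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
            )
            # Project the input with a 1x1 convolution when shapes differ,
            # otherwise pass it through unchanged.
            self.shortcut = (
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                    nn.BatchNorm2d(out_channels),
                )
                if in_channels != out_channels
                else nn.Identity()
            )
            self.activation_fn = nn.ReLU(inplace=True)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.activation_fn(self.blocks(x) + self.shortcut(x))


    # Usage matching the notebook cell further down:
    # dummy = torch.ones((1, 32, 224, 224))
    # BasicBlock(32, 64)(dummy).shape  # -> torch.Size([1, 64, 224, 224])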
diff --git a/.flake8 b/.flake8 index a27b644..c64d7ca 100644 --- a/.flake8 +++ b/.flake8 @@ -1,6 +1,6 @@ [flake8] select = ANN,B,B9,BLK,C,D,DAR,E,F,I,S,W -ignore = E203,E501,W503,ANN101,F401,D202,S404 +ignore = E203,E501,W503,ANN101,ANN002,ANN003,F401,D202,S404,D107 max-line-length = 120 max-complexity = 10 application-import-names = text_recognizer,tests diff --git a/README.md b/README.md index 328cfda..844f2e0 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ TBC - [x] Implement wandb - [x] Implement lr scheduler as a callback - [x] Implement save checkpoint callback - - [ ] Implement TQDM progress bar (Low priority) + - [x] Implement TQDM progress bar (Low priority) - [ ] Check that dataset exists, otherwise download it form the web. Do this in run_experiment.py. - [x] Create repr func for data loaders - [ ] Be able to restart with lr scheduler (May skip this BS) diff --git a/poetry.lock b/poetry.lock index f174cf1..89851d9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -23,6 +23,23 @@ optional = false python-versions = "*" version = "0.1.0" +[[package]] +category = "dev" +description = "The secure Argon2 password hashing algorithm." +name = "argon2-cffi" +optional = false +python-versions = "*" +version = "20.1.0" + +[package.dependencies] +cffi = ">=1.0.0" +six = "*" + +[package.extras] +dev = ["coverage (>=5.0.2)", "hypothesis", "pytest", "sphinx", "wheel", "pre-commit"] +docs = ["sphinx"] +tests = ["coverage (>=5.0.2)", "hypothesis", "pytest"] + [[package]] category = "main" description = "Atomic file writes." @@ -130,7 +147,7 @@ description = "When they're not builtins, they're boltons." name = "boltons" optional = false python-versions = "*" -version = "20.2.0" +version = "20.2.1" [[package]] category = "main" @@ -140,6 +157,17 @@ optional = false python-versions = "*" version = "2020.6.20" +[[package]] +category = "dev" +description = "Foreign Function Interface for Python calling C code." +name = "cffi" +optional = false +python-versions = "*" +version = "1.14.2" + +[package.dependencies] +pycparser = "*" + [[package]] category = "main" description = "Universal encoding detector for Python 2 and 3" @@ -166,7 +194,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" version = "0.4.3" [[package]] -category = "dev" +category = "main" description = "Updated configparser from Python 3.8 for Python 2.6+." 
name = "configparser" optional = false @@ -254,7 +282,7 @@ typing-inspect = "*" dev = ["coverage", "cuvner", "pytest", "tox", "versioneer", "black", "pylint", "pex", "bump2version", "docutils", "check-manifest", "readme-renderer", "pygments", "isort", "mypy", "pytest-sphinx", "towncrier", "marshmallow-union", "marshmallow-enum", "twine", "wheel"] [[package]] -category = "dev" +category = "main" description = "Python bindings for the docker credentials store API" name = "docker-pycreds" optional = false @@ -418,7 +446,7 @@ python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" version = "0.18.2" [[package]] -category = "dev" +category = "main" description = "Git Object Database" name = "gitdb" optional = false @@ -429,7 +457,7 @@ version = "4.0.5" smmap = ">=3.0.1,<4" [[package]] -category = "dev" +category = "main" description = "Python Git Library" name = "gitpython" optional = false @@ -457,7 +485,7 @@ six = ">=1.7" test = ["mock (>=2.0.0)", "pytest (<5.0)"] [[package]] -category = "dev" +category = "main" description = "GraphQL client for Python" name = "gql" optional = false @@ -471,7 +499,7 @@ requests = ">=2.12,<3" six = ">=1.10.0" [[package]] -category = "dev" +category = "main" description = "GraphQL implementation for Python" name = "graphql-core" optional = false @@ -566,8 +594,8 @@ category = "dev" description = "IPython: Productive Interactive Computing" name = "ipython" optional = false -python-versions = ">=3.6" -version = "7.16.1" +python-versions = ">=3.7" +version = "7.17.0" [package.dependencies] appnope = "*" @@ -796,9 +824,10 @@ description = "Python plotting package" name = "matplotlib" optional = false python-versions = ">=3.6" -version = "3.3.0" +version = "3.3.1" [package.dependencies] +certifi = ">=2020.06.20" cycler = ">=0.10" kiwisolver = ">=1.0.1" numpy = ">=1.15" @@ -961,10 +990,11 @@ description = "A web-based notebook environment for interactive computing" name = "notebook" optional = false python-versions = ">=3.5" -version = "6.0.3" +version = "6.1.3" [package.dependencies] Send2Trash = "*" +argon2-cffi = "*" ipykernel = "*" ipython-genutils = "*" jinja2 = "*" @@ -974,12 +1004,13 @@ nbconvert = "*" nbformat = "*" prometheus-client = "*" pyzmq = ">=17" -terminado = ">=0.8.1" +terminado = ">=0.8.3" tornado = ">=5.0" traitlets = ">=4.2.1" [package.extras] -test = ["nose", "coverage", "requests", "nose-warnings-filters", "nbval", "nose-exclude", "selenium", "pytest", "pytest-cov", "nose-exclude"] +docs = ["sphinx", "nbsphinx", "sphinxcontrib-github-alt"] +test = ["nose", "coverage", "requests", "nose-warnings-filters", "nbval", "nose-exclude", "selenium", "pytest", "pytest-cov", "requests-unixsocket"] [[package]] category = "main" @@ -990,7 +1021,7 @@ python-versions = ">=3.6" version = "1.19.1" [[package]] -category = "dev" +category = "main" description = "Python Bindings for the NVIDIA Management Library" name = "nvidia-ml-py3" optional = false @@ -1002,11 +1033,11 @@ category = "main" description = "Wrapper package for OpenCV python bindings." 
name = "opencv-python" optional = false -python-versions = "*" -version = "4.3.0.36" +python-versions = ">=3.5" +version = "4.4.0.42" [package.dependencies] -numpy = ">=1.11.1" +numpy = ">=1.13.1" [[package]] category = "main" @@ -1048,7 +1079,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" version = "0.8.0" [[package]] -category = "dev" +category = "main" description = "File system general utilities" name = "pathtools" optional = false @@ -1119,7 +1150,7 @@ version = "0.8.0" twisted = ["twisted"] [[package]] -category = "dev" +category = "main" description = "Promises/A+ implementation for Python" name = "promise" optional = false @@ -1138,13 +1169,13 @@ description = "Library for building powerful interactive command lines in Python name = "prompt-toolkit" optional = false python-versions = ">=3.6.1" -version = "3.0.5" +version = "3.0.6" [package.dependencies] wcwidth = "*" [[package]] -category = "dev" +category = "main" description = "Cross-platform lib for process and system monitoring in Python." name = "psutil" optional = false @@ -1179,6 +1210,14 @@ optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" version = "2.6.0" +[[package]] +category = "dev" +description = "C parser in Python" +name = "pycparser" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +version = "2.20" + [[package]] category = "main" description = "Python docstring style checker" @@ -1257,7 +1296,7 @@ description = "Pytest plugin for measuring coverage." name = "pytest-cov" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" -version = "2.10.0" +version = "2.10.1" [package.dependencies] coverage = ">=4.4" @@ -1298,7 +1337,7 @@ marker = "python_version == \"3.7\"" name = "pytype" optional = false python-versions = "<3.9,>=3.5" -version = "2020.7.24" +version = "2020.8.10" [package.dependencies] attrs = "*" @@ -1335,7 +1374,7 @@ python-versions = "*" version = "0.5.7" [[package]] -category = "dev" +category = "main" description = "YAML parser and emitter for Python" name = "pyyaml" optional = false @@ -1348,7 +1387,7 @@ description = "Python bindings for 0MQ" name = "pyzmq" optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*" -version = "19.0.1" +version = "19.0.2" [[package]] category = "dev" @@ -1452,12 +1491,12 @@ python-versions = "*" version = "1.5.0" [[package]] -category = "dev" +category = "main" description = "Python client for Sentry (https://sentry.io)" name = "sentry-sdk" optional = false python-versions = "*" -version = "0.16.2" +version = "0.16.5" [package.dependencies] certifi = "*" @@ -1479,7 +1518,7 @@ sqlalchemy = ["sqlalchemy (>=1.2)"] tornado = ["tornado (>=5)"] [[package]] -category = "dev" +category = "main" description = "A generator library for concise, unambiguous and URL-safe UUIDs." 
name = "shortuuid" optional = false @@ -1495,7 +1534,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" version = "1.15.0" [[package]] -category = "dev" +category = "main" description = "A pure Python implementation of a sliding window memory map manager" name = "smmap" optional = false @@ -1516,7 +1555,7 @@ description = "Python documentation generator" name = "sphinx" optional = false python-versions = ">=3.5" -version = "3.1.2" +version = "3.2.1" [package.dependencies] Jinja2 = ">=2.3" @@ -1655,7 +1694,7 @@ python = "<3.8" version = ">=1.7.0" [[package]] -category = "dev" +category = "main" description = "A backport of the subprocess module from Python 3 for use on 2.x." name = "subprocess32" optional = false @@ -1699,8 +1738,8 @@ category = "main" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" name = "torch" optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.6.0" -version = "1.5.1" +python-versions = ">=3.6.1" +version = "1.6.0" [package.dependencies] future = "*" @@ -1720,12 +1759,12 @@ description = "image and video datasets and models for torch deep learning" name = "torchvision" optional = false python-versions = "*" -version = "0.6.1" +version = "0.7.0" [package.dependencies] numpy = "*" pillow = ">=4.1.1" -torch = "1.5.1" +torch = "1.6.0" [package.extras] scipy = ["scipy"] @@ -1744,7 +1783,7 @@ description = "Fast, Extensible Progress Meter" name = "tqdm" optional = false python-versions = ">=2.6, !=3.0.*, !=3.1.*" -version = "4.48.0" +version = "4.48.2" [package.extras] dev = ["py-make (>=0.1.0)", "twine", "argopt", "pydoc-markdown"] @@ -1819,12 +1858,12 @@ secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "pyOpenSSL (>=0 socks = ["PySocks (>=1.5.6,<1.5.7 || >1.5.7,<2.0)"] [[package]] -category = "dev" +category = "main" description = "A CLI and library for interacting with the Weights and Biases API." 
name = "wandb" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -version = "0.9.4" +version = "0.9.5" [package.dependencies] Click = ">=7.0" @@ -1849,7 +1888,7 @@ gcp = ["google-cloud-storage"] kubeflow = ["kubernetes", "minio", "google-cloud-storage", "sh"] [[package]] -category = "dev" +category = "main" description = "Filesystem events monitoring" name = "watchdog" optional = false @@ -1931,7 +1970,7 @@ docs = ["sphinx", "jaraco.packaging (>=3.2)", "rst.linker (>=1.9)"] testing = ["jaraco.itertools", "func-timeout"] [metadata] -content-hash = "90078f3c0d16b4e2e67b4f5ad22219067d8936dfca527d130ec0a179d2c55cb0" +content-hash = "8aea4adc15683cb4bdd650d199725418e7778687d311bcb6a02db94fa05719f7" lock-version = "1.0" python-versions = "^3.7" @@ -1948,6 +1987,24 @@ appnope = [ {file = "appnope-0.1.0-py2.py3-none-any.whl", hash = "sha256:5b26757dc6f79a3b7dc9fab95359328d5747fcb2409d331ea66d0272b90ab2a0"}, {file = "appnope-0.1.0.tar.gz", hash = "sha256:8b995ffe925347a2138d7ac0fe77155e4311a0ea6d6da4f5128fe4b3cbe5ed71"}, ] +argon2-cffi = [ + {file = "argon2-cffi-20.1.0.tar.gz", hash = "sha256:d8029b2d3e4b4cea770e9e5a0104dd8fa185c1724a0f01528ae4826a6d25f97d"}, + {file = "argon2_cffi-20.1.0-cp27-cp27m-macosx_10_6_intel.whl", hash = "sha256:6ea92c980586931a816d61e4faf6c192b4abce89aa767ff6581e6ddc985ed003"}, + {file = "argon2_cffi-20.1.0-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:05a8ac07c7026542377e38389638a8a1e9b78f1cd8439cd7493b39f08dd75fbf"}, + {file = "argon2_cffi-20.1.0-cp27-cp27m-win32.whl", hash = "sha256:0bf066bc049332489bb2d75f69216416329d9dc65deee127152caeb16e5ce7d5"}, + {file = "argon2_cffi-20.1.0-cp27-cp27m-win_amd64.whl", hash = "sha256:57358570592c46c420300ec94f2ff3b32cbccd10d38bdc12dc6979c4a8484fbc"}, + {file = "argon2_cffi-20.1.0-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:7d455c802727710e9dfa69b74ccaab04568386ca17b0ad36350b622cd34606fe"}, + {file = "argon2_cffi-20.1.0-cp35-abi3-manylinux1_x86_64.whl", hash = "sha256:b160416adc0f012fb1f12588a5e6954889510f82f698e23ed4f4fa57f12a0647"}, + {file = "argon2_cffi-20.1.0-cp35-cp35m-win32.whl", hash = "sha256:9bee3212ba4f560af397b6d7146848c32a800652301843df06b9e8f68f0f7361"}, + {file = "argon2_cffi-20.1.0-cp35-cp35m-win_amd64.whl", hash = "sha256:392c3c2ef91d12da510cfb6f9bae52512a4552573a9e27600bdb800e05905d2b"}, + {file = "argon2_cffi-20.1.0-cp36-cp36m-win32.whl", hash = "sha256:ba7209b608945b889457f949cc04c8e762bed4fe3fec88ae9a6b7765ae82e496"}, + {file = "argon2_cffi-20.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:da7f0445b71db6d3a72462e04f36544b0de871289b0bc8a7cc87c0f5ec7079fa"}, + {file = "argon2_cffi-20.1.0-cp37-abi3-macosx_10_6_intel.whl", hash = "sha256:cc0e028b209a5483b6846053d5fd7165f460a1f14774d79e632e75e7ae64b82b"}, + {file = "argon2_cffi-20.1.0-cp37-cp37m-win32.whl", hash = "sha256:18dee20e25e4be86680b178b35ccfc5d495ebd5792cd00781548d50880fee5c5"}, + {file = "argon2_cffi-20.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:6678bb047373f52bcff02db8afab0d2a77d83bde61cfecea7c5c62e2335cb203"}, + {file = "argon2_cffi-20.1.0-cp38-cp38-win32.whl", hash = "sha256:77e909cc756ef81d6abb60524d259d959bab384832f0c651ed7dcb6e5ccdbb78"}, + {file = "argon2_cffi-20.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:9dfd5197852530294ecb5795c97a823839258dfd5eb9420233c7cfedec2058f2"}, +] atomicwrites = [ {file = "atomicwrites-1.4.0-py2.py3-none-any.whl", hash = "sha256:6d1784dea7c0c8d4a5172b6c620f40b6e4cbfdf96d783691f2e1302a7b88e197"}, {file = "atomicwrites-1.4.0.tar.gz", hash = 
"sha256:ae70396ad1a434f9c7046fd2dd196fc04b12f9e91ffb859164193be8b6168a7a"}, @@ -1982,13 +2039,43 @@ blessings = [ {file = "blessings-1.7.tar.gz", hash = "sha256:98e5854d805f50a5b58ac2333411b0482516a8210f23f43308baeb58d77c157d"}, ] boltons = [ - {file = "boltons-20.2.0-py2.py3-none-any.whl", hash = "sha256:1567a4ab991a52be4e9a778c6c433e50237018b322759146b07732b71e57ef67"}, - {file = "boltons-20.2.0.tar.gz", hash = "sha256:d367506c0b32042bb1ee3bf7899f2dcc8492dceb42ce3727b89e174d85bffe6e"}, + {file = "boltons-20.2.1-py2.py3-none-any.whl", hash = "sha256:3dd8a8e3c1886e7f7ba3422b50f55a66e1700161bf01b919d098e7d96dd2d9b6"}, + {file = "boltons-20.2.1.tar.gz", hash = "sha256:dd362291a460cc1e0c2e91cc6a60da3036ced77099b623112e8f833e6734bdc5"}, ] certifi = [ {file = "certifi-2020.6.20-py2.py3-none-any.whl", hash = "sha256:8fc0819f1f30ba15bdb34cceffb9ef04d99f420f68eb75d901e9560b8749fc41"}, {file = "certifi-2020.6.20.tar.gz", hash = "sha256:5930595817496dd21bb8dc35dad090f1c2cd0adfaf21204bf6732ca5d8ee34d3"}, ] +cffi = [ + {file = "cffi-1.14.2-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:da9d3c506f43e220336433dffe643fbfa40096d408cb9b7f2477892f369d5f82"}, + {file = "cffi-1.14.2-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:23e44937d7695c27c66a54d793dd4b45889a81b35c0751ba91040fe825ec59c4"}, + {file = "cffi-1.14.2-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:0da50dcbccd7cb7e6c741ab7912b2eff48e85af217d72b57f80ebc616257125e"}, + {file = "cffi-1.14.2-cp27-cp27m-win32.whl", hash = "sha256:76ada88d62eb24de7051c5157a1a78fd853cca9b91c0713c2e973e4196271d0c"}, + {file = "cffi-1.14.2-cp27-cp27m-win_amd64.whl", hash = "sha256:15a5f59a4808f82d8ec7364cbace851df591c2d43bc76bcbe5c4543a7ddd1bf1"}, + {file = "cffi-1.14.2-cp27-cp27mu-manylinux1_i686.whl", hash = "sha256:e4082d832e36e7f9b2278bc774886ca8207346b99f278e54c9de4834f17232f7"}, + {file = "cffi-1.14.2-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:57214fa5430399dffd54f4be37b56fe22cedb2b98862550d43cc085fb698dc2c"}, + {file = "cffi-1.14.2-cp35-cp35m-macosx_10_9_x86_64.whl", hash = "sha256:6843db0343e12e3f52cc58430ad559d850a53684f5b352540ca3f1bc56df0731"}, + {file = "cffi-1.14.2-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:577791f948d34d569acb2d1add5831731c59d5a0c50a6d9f629ae1cefd9ca4a0"}, + {file = "cffi-1.14.2-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:8662aabfeab00cea149a3d1c2999b0731e70c6b5bac596d95d13f643e76d3d4e"}, + {file = "cffi-1.14.2-cp35-cp35m-win32.whl", hash = "sha256:837398c2ec00228679513802e3744d1e8e3cb1204aa6ad408b6aff081e99a487"}, + {file = "cffi-1.14.2-cp35-cp35m-win_amd64.whl", hash = "sha256:bf44a9a0141a082e89c90e8d785b212a872db793a0080c20f6ae6e2a0ebf82ad"}, + {file = "cffi-1.14.2-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:29c4688ace466a365b85a51dcc5e3c853c1d283f293dfcc12f7a77e498f160d2"}, + {file = "cffi-1.14.2-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:99cc66b33c418cd579c0f03b77b94263c305c389cb0c6972dac420f24b3bf123"}, + {file = "cffi-1.14.2-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:65867d63f0fd1b500fa343d7798fa64e9e681b594e0a07dc934c13e76ee28fb1"}, + {file = "cffi-1.14.2-cp36-cp36m-win32.whl", hash = "sha256:f5033952def24172e60493b68717792e3aebb387a8d186c43c020d9363ee7281"}, + {file = "cffi-1.14.2-cp36-cp36m-win_amd64.whl", hash = "sha256:7057613efefd36cacabbdbcef010e0a9c20a88fc07eb3e616019ea1692fa5df4"}, + {file = "cffi-1.14.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6539314d84c4d36f28d73adc1b45e9f4ee2a89cdc7e5d2b0a6dbacba31906798"}, + {file = "cffi-1.14.2-cp37-cp37m-manylinux1_i686.whl", 
hash = "sha256:672b539db20fef6b03d6f7a14b5825d57c98e4026401fce838849f8de73fe4d4"}, + {file = "cffi-1.14.2-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:95e9094162fa712f18b4f60896e34b621df99147c2cee216cfa8f022294e8e9f"}, + {file = "cffi-1.14.2-cp37-cp37m-win32.whl", hash = "sha256:b9aa9d8818c2e917fa2c105ad538e222a5bce59777133840b93134022a7ce650"}, + {file = "cffi-1.14.2-cp37-cp37m-win_amd64.whl", hash = "sha256:e4b9b7af398c32e408c00eb4e0d33ced2f9121fd9fb978e6c1b57edd014a7d15"}, + {file = "cffi-1.14.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e613514a82539fc48291d01933951a13ae93b6b444a88782480be32245ed4afa"}, + {file = "cffi-1.14.2-cp38-cp38-manylinux1_i686.whl", hash = "sha256:9b219511d8b64d3fa14261963933be34028ea0e57455baf6781fe399c2c3206c"}, + {file = "cffi-1.14.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:c0b48b98d79cf795b0916c57bebbc6d16bb43b9fc9b8c9f57f4cf05881904c75"}, + {file = "cffi-1.14.2-cp38-cp38-win32.whl", hash = "sha256:15419020b0e812b40d96ec9d369b2bc8109cc3295eac6e013d3261343580cc7e"}, + {file = "cffi-1.14.2-cp38-cp38-win_amd64.whl", hash = "sha256:12a453e03124069b6896107ee133ae3ab04c624bb10683e1ed1c1663df17c13c"}, + {file = "cffi-1.14.2.tar.gz", hash = "sha256:ae8f34d50af2c2154035984b8b5fc5d9ed63f32fe615646ab435b05b132ca91b"}, +] chardet = [ {file = "chardet-3.0.4-py2.py3-none-any.whl", hash = "sha256:fc323ffcaeaed0e0a02bf4d117757b98aed530d9ed4531e3e15460124c106691"}, {file = "chardet-3.0.4.tar.gz", hash = "sha256:84ab92ed1c4d4f16916e05906b6b75a6c0fb5db821cc65e70cbd64a3e2a5eaae"}, @@ -2186,8 +2273,8 @@ ipykernel = [ {file = "ipykernel-5.3.4.tar.gz", hash = "sha256:9b2652af1607986a1b231c62302d070bc0534f564c393a5d9d130db9abbbe89d"}, ] ipython = [ - {file = "ipython-7.16.1-py3-none-any.whl", hash = "sha256:2dbcc8c27ca7d3cfe4fcdff7f45b27f9a8d3edfa70ff8024a71c7a8eb5f09d64"}, - {file = "ipython-7.16.1.tar.gz", hash = "sha256:9f4fcb31d3b2c533333893b9172264e4821c1ac91839500f31bd43f2c59b3ccf"}, + {file = "ipython-7.17.0-py3-none-any.whl", hash = "sha256:5a8f159ca8b22b9a0a1f2a28befe5ad2b703339afb58c2ffe0d7c8d7a3af5999"}, + {file = "ipython-7.17.0.tar.gz", hash = "sha256:b70974aaa2674b05eb86a910c02ed09956a33f2dd6c71afc60f0b128a77e7f28"}, ] ipython-genutils = [ {file = "ipython_genutils-0.2.0-py2.py3-none-any.whl", hash = "sha256:72dd37233799e619666c9f639a9da83c34013a73e8bbc79a7a6348d93c61fab8"}, @@ -2292,27 +2379,24 @@ marshmallow = [ {file = "marshmallow-3.7.1.tar.gz", hash = "sha256:a2a5eefb4b75a3b43f05be1cca0b6686adf56af7465c3ca629e5ad8d1e1fe13d"}, ] matplotlib = [ - {file = "matplotlib-3.3.0-1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:b0786ac32983191fcd9cc0230b4ec2f8b3c25dee9beca46ca506c5d6cc5c593d"}, - {file = "matplotlib-3.3.0-1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:f9753c6292d5a1fe46828feb38d1de1820e3ea109a5fea0b6ea1dca6e9d0b220"}, - {file = "matplotlib-3.3.0-1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6aa7ea00ad7d898704ffed46e83efd7ec985beba57f507c957979f080678b9ea"}, - {file = "matplotlib-3.3.0-1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:19cf4db0272da286863a50406f6430101af129f288c421b1a7f33ddfc8d0180f"}, - {file = "matplotlib-3.3.0-1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ebb6168c9330309b1f3360d36c481d8cd621a490cf2a69c9d6625b2a76777c12"}, - {file = "matplotlib-3.3.0-1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:695b4165520bdfe381d15a6db03778babb265fee7affdc43e169a881f3f329bc"}, - {file = "matplotlib-3.3.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = 
"sha256:9ccc651261b7044ffc3b1e2f9af17b1ef4c6a12fc080b5a7353ef0b53a50be28"}, - {file = "matplotlib-3.3.0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:e3868686f3023644523df486fc224b0af4349f3cdb933b0a71f261a574d7b65f"}, - {file = "matplotlib-3.3.0-cp36-cp36m-win32.whl", hash = "sha256:7c9adba58a67d23cc131c4189da56cb1d0f18a237c43188d831a44e4fc5df15a"}, - {file = "matplotlib-3.3.0-cp36-cp36m-win_amd64.whl", hash = "sha256:855bb281f3cc8e23ef66064a2beb229674fdf785638091fc82a172e3e84c2780"}, - {file = "matplotlib-3.3.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:7ce8f5364c74aac06abad84d8744d659bd86036e86c4ebf14c75ae4292597b46"}, - {file = "matplotlib-3.3.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:605e4d43b421524ad955a56535391e02866d07bce27c644e2c99e25fb59d63d1"}, - {file = "matplotlib-3.3.0-cp37-cp37m-win32.whl", hash = "sha256:cef05e9a2302f96d6f0666ee70ac7715cbc12e3802d8b8eb80bacd6ab81a0a24"}, - {file = "matplotlib-3.3.0-cp37-cp37m-win_amd64.whl", hash = "sha256:bf8d527a2eb9a5db1c9e5e405d1b1c4e66be983620c9ce80af6aae430d9a0c9c"}, - {file = "matplotlib-3.3.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c06ea133b44805d42f2507cb3503f6647b0c7918f1900b5063f5a8a69c63f6d2"}, - {file = "matplotlib-3.3.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:09b4096748178bcc764b81587b00ab720eac24965d2bf44ecbe09dcf4e8ed253"}, - {file = "matplotlib-3.3.0-cp38-cp38-win32.whl", hash = "sha256:c1f850908600efa60f81ad14eedbaf7cb17185a2c6d26586ae826ae5ce21f6e0"}, - {file = "matplotlib-3.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:2a9d10930406748b50f60c5fa74c399a1c1080aa6ce6e3fe5f38473b02f6f06d"}, - {file = "matplotlib-3.3.0-cp39-cp39-win32.whl", hash = "sha256:244a9088140a4c540e0a2db9c8ada5ad12520efded592a46e5bc43ff8f0fd0aa"}, - {file = "matplotlib-3.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:f74c39621b03cec7bc08498f140192ac26ca940ef20beac6dfad3714d2298b2a"}, - {file = "matplotlib-3.3.0.tar.gz", hash = "sha256:24e8db94948019d531ce0bcd637ac24b1c8f6744ac86d2aa0eb6dbaeb1386f82"}, + {file = "matplotlib-3.3.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:282f8a077a1217f9f2ac178596f27c1ae94abbc6e7b785e1b8f25e83918e9199"}, + {file = "matplotlib-3.3.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:83ae7261f4d5ab387be2caee29c4f499b1566f31c8ac97a0b8ab61afd9e3da92"}, + {file = "matplotlib-3.3.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:1f9cf2b8500b833714a193cb24281153f5072d55b2e486009f1e81f0b7da3410"}, + {file = "matplotlib-3.3.1-cp36-cp36m-win32.whl", hash = "sha256:0dc15e1ad84ec06bf0c315e6c4c2cced13a21ce4c2b4955bb75097064a4b1e92"}, + {file = "matplotlib-3.3.1-cp36-cp36m-win_amd64.whl", hash = "sha256:ffbae66e2db70dc330cb3299525f97e1c0efdfc763e04e1a4e08f968c7ad21f0"}, + {file = "matplotlib-3.3.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:88c6ab4a32a7447dad236b8371612aaba5c967d632ff11999e0478dd687f2c58"}, + {file = "matplotlib-3.3.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:cc2d6b47c8fee89da982a312b54949ec0cd6a7976a8cafb5b62dea6c9883a14d"}, + {file = "matplotlib-3.3.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:636c6330a7dcb18bac114dbeaff314fbbb0c11682f9a9601de69a50e331d18d7"}, + {file = "matplotlib-3.3.1-cp37-cp37m-win32.whl", hash = "sha256:73a493e340064e8fe03207d9333b68baca30d9f0da543ae4af6b6b4f13f0fe05"}, + {file = "matplotlib-3.3.1-cp37-cp37m-win_amd64.whl", hash = "sha256:6739b6cd9278d5cb337df0bd4400ad37bbd04c6dc7aa2c65e1e83a02bc4cc6fd"}, + {file = "matplotlib-3.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = 
"sha256:79f0c4730ad422ecb6bda814c9a9b375df36d6bd5a49eaa14e92e5f5e3e95ac3"}, + {file = "matplotlib-3.3.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:e4d6d3afc454b4afc0d9d0ed52a8fa40a1b0d8f33c8e143e49a5833a7e32266b"}, + {file = "matplotlib-3.3.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:96a5e667308dbf45670370d9dffb974e73b15bac0df0b5f3fb0b0ac7a572290e"}, + {file = "matplotlib-3.3.1-cp38-cp38-win32.whl", hash = "sha256:bd8fceaa3494b531d43b6206966ba15705638137fc2dc5da5ee560cf9476867b"}, + {file = "matplotlib-3.3.1-cp38-cp38-win_amd64.whl", hash = "sha256:1507c2a8e4662f6fa1d3ecc760782b158df8a3244ecc21c1d8dbb1cd0b3f872e"}, + {file = "matplotlib-3.3.1-pp36-pypy36_pp73-macosx_10_9_x86_64.whl", hash = "sha256:2c3619ec2a5ead430a4536ebf8c77ea55d8ce36418919f831d35bc657ed5f27e"}, + {file = "matplotlib-3.3.1-pp36-pypy36_pp73-manylinux2010_x86_64.whl", hash = "sha256:9703bc00a94a94c4e94b2ea0fbfbc9d2bb21159733134639fd931b6606c5c47e"}, + {file = "matplotlib-3.3.1.tar.gz", hash = "sha256:87f53bcce90772f942c2db56736788b39332d552461a5cb13f05ff45c1680f0e"}, ] mccabe = [ {file = "mccabe-0.6.1-py2.py3-none-any.whl", hash = "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42"}, @@ -2375,8 +2459,8 @@ nltk = [ {file = "nltk-3.5.zip", hash = "sha256:845365449cd8c5f9731f7cb9f8bd6fd0767553b9d53af9eb1b3abf7700936b35"}, ] notebook = [ - {file = "notebook-6.0.3-py3-none-any.whl", hash = "sha256:3edc616c684214292994a3af05eaea4cc043f6b4247d830f3a2f209fa7639a80"}, - {file = "notebook-6.0.3.tar.gz", hash = "sha256:47a9092975c9e7965ada00b9a20f0cf637d001db60d241d479f53c0be117ad48"}, + {file = "notebook-6.1.3-py3-none-any.whl", hash = "sha256:964cc40cff68e473f3778aef9266e867f7703cb4aebdfd250f334efe02f64c86"}, + {file = "notebook-6.1.3.tar.gz", hash = "sha256:9990d51b9931a31e681635899aeb198b4c4b41586a9e87fbfaaed1a71d0a05b6"}, ] numpy = [ {file = "numpy-1.19.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:b1cca51512299841bf69add3b75361779962f9cee7d9ee3bb446d5982e925b69"}, @@ -2410,25 +2494,27 @@ nvidia-ml-py3 = [ {file = "nvidia-ml-py3-7.352.0.tar.gz", hash = "sha256:390f02919ee9d73fe63a98c73101061a6b37fa694a793abf56673320f1f51277"}, ] opencv-python = [ - {file = "opencv_python-4.3.0.36-cp35-cp35m-manylinux2014_i686.whl", hash = "sha256:c4f1e9d963c8f370284afa87fcf521cc8439a610a500bf8ede27fd64dd9050bd"}, - {file = "opencv_python-4.3.0.36-cp35-cp35m-manylinux2014_x86_64.whl", hash = "sha256:55e1d7a2d11c40ea5b53aabe5c4122038803c7d492505c8f93af077aa7fe2ce1"}, - {file = "opencv_python-4.3.0.36-cp35-cp35m-win32.whl", hash = "sha256:76ddc6daf8607eda1d866395dcf98526ef96f3e616d8c37ccc7629f9aaf6d4d4"}, - {file = "opencv_python-4.3.0.36-cp35-cp35m-win_amd64.whl", hash = "sha256:2fe704e35808cf6b17b793e89fd00e9ef7779f85f274666a4e092671aedd09c0"}, - {file = "opencv_python-4.3.0.36-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:2ec6502bfac01b27ac06daf7bc9f7a4f482a6a0d8e1b30e15c411d478454a19f"}, - {file = "opencv_python-4.3.0.36-cp36-cp36m-manylinux2014_i686.whl", hash = "sha256:f67c1d92ff96c6c106f786b7ef9b9ab448fa03ef28cb7bb6f0f7b857b65bc158"}, - {file = "opencv_python-4.3.0.36-cp36-cp36m-manylinux2014_x86_64.whl", hash = "sha256:c93b1198c85175a9fa9a9839c4da55c7ab9c5f57256f2e4211cd6c91d7d422e8"}, - {file = "opencv_python-4.3.0.36-cp36-cp36m-win32.whl", hash = "sha256:1bf486680a16d739f7852a62865b72eb7692df584694815774ba97b471b8bc3f"}, - {file = "opencv_python-4.3.0.36-cp36-cp36m-win_amd64.whl", hash = "sha256:f6fa2834d85c78865ca6e3de563916086cb8c83c3f2ef80924fcd07005f05df9"}, - {file = 
"opencv_python-4.3.0.36-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:fa1a6d149a1a5e0bc54c737a59fe38d75384a092ae5e35f9b876fbb621f755c6"}, - {file = "opencv_python-4.3.0.36-cp37-cp37m-manylinux2014_i686.whl", hash = "sha256:4b93b5f8df187e4dba9fb25c46fa8cf342c257de144f7c86d75c06416566a199"}, - {file = "opencv_python-4.3.0.36-cp37-cp37m-manylinux2014_x86_64.whl", hash = "sha256:fd457deedcf153dd6805a2b4d891ac2a0969566d3755fbf48a3ffb53978c9ed1"}, - {file = "opencv_python-4.3.0.36-cp37-cp37m-win32.whl", hash = "sha256:ef4ac758a4e2caee80ef9c86b83a279d6f132c9e7ae77957cf74013928814e05"}, - {file = "opencv_python-4.3.0.36-cp37-cp37m-win_amd64.whl", hash = "sha256:d765c44827778cbe6bc8f272cd61514e8509b93fd24dd3324cd4abddf2026b11"}, - {file = "opencv_python-4.3.0.36-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:677f61332436e22f83a1e4e6f6863a760734fbc8029ba6a8ef0af4554cde6f93"}, - {file = "opencv_python-4.3.0.36-cp38-cp38-manylinux2014_i686.whl", hash = "sha256:eb709245e56f6693d297f8818ff8e6c017fa80fdb5a923c64be623a678c7150e"}, - {file = "opencv_python-4.3.0.36-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:210ab40c8c9dadc7dc9ed7beebe2e0a2415a744f8d6857762a80c8e0fcc477c8"}, - {file = "opencv_python-4.3.0.36-cp38-cp38-win32.whl", hash = "sha256:156e2954d5b38b676e8a24d66703cf15f252e24ec49db7e842a8b5eed46074ba"}, - {file = "opencv_python-4.3.0.36-cp38-cp38-win_amd64.whl", hash = "sha256:1ea08f22246ccd33174d59edfa3f13930bf2c28096568242090bd9d8770fb8a8"}, + {file = "opencv-python-4.4.0.42.tar.gz", hash = "sha256:0039506845d7076e6871c0075227881a84de69799d70ed37c8704d203b740911"}, + {file = "opencv_python-4.4.0.42-cp35-cp35m-macosx_10_13_x86_64.whl", hash = "sha256:608dae0444065669fc26fa6bf1653072e40735b33dfa514c74a6165563a99e97"}, + {file = "opencv_python-4.4.0.42-cp35-cp35m-manylinux2014_i686.whl", hash = "sha256:fae421571a7709ae0baa9bfd08177165bc1d56d7c79c806d12627d58a6faf2d1"}, + {file = "opencv_python-4.4.0.42-cp35-cp35m-manylinux2014_x86_64.whl", hash = "sha256:a35b3a3540623090ba5fdad7ed97d0d75ca80ee55f5d7c1cecddda723665c0f8"}, + {file = "opencv_python-4.4.0.42-cp35-cp35m-win32.whl", hash = "sha256:177f14625ea164f38b5b6f5c2b316f8ff8163e996cc0432de90f475956a9069a"}, + {file = "opencv_python-4.4.0.42-cp35-cp35m-win_amd64.whl", hash = "sha256:093c1bfa6da24a9d4dde2d54a22b9acfb46f5cb2c50d7387356cf897f0db0ab9"}, + {file = "opencv_python-4.4.0.42-cp36-cp36m-macosx_10_13_x86_64.whl", hash = "sha256:f5b82cd49b560e004608ca53ce625e5167b41f0fdc610758d6989083e26b5a03"}, + {file = "opencv_python-4.4.0.42-cp36-cp36m-manylinux2014_i686.whl", hash = "sha256:bcb24c4f82fa79f049db4bfd0da1d18a315da66a55aa3d4cde81d1ec18f0a7ff"}, + {file = "opencv_python-4.4.0.42-cp36-cp36m-manylinux2014_x86_64.whl", hash = "sha256:cb00bbd41268f5fa0fa327ca30f7621a8ece983e0d8ae472e2ffe7ab1617606f"}, + {file = "opencv_python-4.4.0.42-cp36-cp36m-win32.whl", hash = "sha256:78a0796ec15d1b41f5a87c41f339356eb04858749c8845936be532cb3436f898"}, + {file = "opencv_python-4.4.0.42-cp36-cp36m-win_amd64.whl", hash = "sha256:34d0d2c9a80c02d55f83a67c29fc4145a9dcf1fe3ddef0535d0b0d9c7b89b8d2"}, + {file = "opencv_python-4.4.0.42-cp37-cp37m-macosx_10_13_x86_64.whl", hash = "sha256:505bd984aae24c489910bbd168e515580d62bc1dbdd5ee36f2c2d42803c4b795"}, + {file = "opencv_python-4.4.0.42-cp37-cp37m-manylinux2014_i686.whl", hash = "sha256:17663f0469b2944b7d4051d4b1c425235d153777f17310c6990370bbb4d12695"}, + {file = "opencv_python-4.4.0.42-cp37-cp37m-manylinux2014_x86_64.whl", hash = "sha256:d19cbbcdc05caf7b41e28898f05c076c94b07647b4556c8327663a40acd4e3bd"}, 
+ {file = "opencv_python-4.4.0.42-cp37-cp37m-win32.whl", hash = "sha256:ccd92a126d253c7bd65b36184fe097a0eea77da4d72d427e1630633bc586233e"}, + {file = "opencv_python-4.4.0.42-cp37-cp37m-win_amd64.whl", hash = "sha256:80a51a797f71ee4a401d281749bb096370007202204bbcd1ecfc9ead58bd3b0b"}, + {file = "opencv_python-4.4.0.42-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:02f7e31c710a7c82229fc4ad98e7e4cf265d19ab52b4451cbe7e33a840fe6595"}, + {file = "opencv_python-4.4.0.42-cp38-cp38-manylinux2014_i686.whl", hash = "sha256:b3ae62990faebefbc3cbc5430f7b6de57bafdcf297134113a9c6d6ccfce4438f"}, + {file = "opencv_python-4.4.0.42-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:fec63240ea3179a2b4176a3256a99682129d75450a15bf2807904600ec64b45a"}, + {file = "opencv_python-4.4.0.42-cp38-cp38-win32.whl", hash = "sha256:324a2c680caae9edbd843a355a2e03792cbd23faf6c24c20dd594fa9aac80765"}, + {file = "opencv_python-4.4.0.42-cp38-cp38-win_amd64.whl", hash = "sha256:a6e1d065a45ec1bf466f47bdf767e0505b244c9470140cf8bab1dd8835f0d3ee"}, ] packaging = [ {file = "packaging-20.4-py2.py3-none-any.whl", hash = "sha256:998416ba6962ae7fbd6596850b80e17859a5753ba17c32284f67bfff33784181"}, @@ -2500,8 +2586,8 @@ promise = [ {file = "promise-2.3.tar.gz", hash = "sha256:dfd18337c523ba4b6a58801c164c1904a9d4d1b1747c7d5dbf45b693a49d93d0"}, ] prompt-toolkit = [ - {file = "prompt_toolkit-3.0.5-py3-none-any.whl", hash = "sha256:df7e9e63aea609b1da3a65641ceaf5bc7d05e0a04de5bd45d05dbeffbabf9e04"}, - {file = "prompt_toolkit-3.0.5.tar.gz", hash = "sha256:563d1a4140b63ff9dd587bda9557cffb2fe73650205ab6f4383092fb882e7dc8"}, + {file = "prompt_toolkit-3.0.6-py3-none-any.whl", hash = "sha256:683397077a64cd1f750b71c05afcfc6612a7300cb6932666531e5a54f38ea564"}, + {file = "prompt_toolkit-3.0.6.tar.gz", hash = "sha256:7630ab85a23302839a0f26b31cc24f518e6155dea1ed395ea61b42c45941b6a6"}, ] psutil = [ {file = "psutil-5.7.2-cp27-none-win32.whl", hash = "sha256:f2018461733b23f308c298653c8903d32aaad7873d25e1d228765e91ae42c3f2"}, @@ -2528,6 +2614,10 @@ pycodestyle = [ {file = "pycodestyle-2.6.0-py2.py3-none-any.whl", hash = "sha256:2295e7b2f6b5bd100585ebcb1f616591b652db8a741695b3d8f5d28bdc934367"}, {file = "pycodestyle-2.6.0.tar.gz", hash = "sha256:c58a7d2815e0e8d7972bf1803331fb0152f867bd89adf8a01dfd55085434192e"}, ] +pycparser = [ + {file = "pycparser-2.20-py2.py3-none-any.whl", hash = "sha256:7582ad22678f0fcd81102833f60ef8d0e57288b6b5fb00323d101be910e35705"}, + {file = "pycparser-2.20.tar.gz", hash = "sha256:2d475327684562c3a96cc71adf7dc8c4f0565175cf86b6d7a404ff4c771f15f0"}, +] pydocstyle = [ {file = "pydocstyle-5.0.2-py3-none-any.whl", hash = "sha256:da7831660b7355307b32778c4a0dbfb137d89254ef31a2b2978f50fc0b4d7586"}, {file = "pydocstyle-5.0.2.tar.gz", hash = "sha256:f4f5d210610c2d153fae39093d44224c17429e2ad7da12a8b419aba5c2f614b5"}, @@ -2552,8 +2642,8 @@ pytest = [ {file = "pytest-5.4.3.tar.gz", hash = "sha256:7979331bfcba207414f5e1263b5a0f8f521d0f457318836a7355531ed1a4c7d8"}, ] pytest-cov = [ - {file = "pytest-cov-2.10.0.tar.gz", hash = "sha256:1a629dc9f48e53512fcbfda6b07de490c374b0c83c55ff7a1720b3fccff0ac87"}, - {file = "pytest_cov-2.10.0-py2.py3-none-any.whl", hash = "sha256:6e6d18092dce6fad667cd7020deed816f858ad3b49d5b5e2b1cc1c97a4dba65c"}, + {file = "pytest-cov-2.10.1.tar.gz", hash = "sha256:47bd0ce14056fdd79f93e1713f88fad7bdcc583dcd7783da86ef2f085a0bb88e"}, + {file = "pytest_cov-2.10.1-py2.py3-none-any.whl", hash = "sha256:45ec2d5182f89a81fc3eb29e3d1ed3113b9e9a873bcddb2a71faaab066110191"}, ] pytest-mock = [ {file = "pytest-mock-3.2.0.tar.gz", hash = 
"sha256:7122d55505d5ed5a6f3df940ad174b3f606ecae5e9bc379569cdcbd4cd9d2b83"}, @@ -2564,13 +2654,13 @@ python-dateutil = [ {file = "python_dateutil-2.8.1-py2.py3-none-any.whl", hash = "sha256:75bb3f31ea686f1197762692a9ee6a7550b59fc6ca3a1f4b5d7e32fb98e2da2a"}, ] pytype = [ - {file = "pytype-2020.7.24-cp35-cp35m-macosx_10_14_x86_64.whl", hash = "sha256:a977290a6c2e12c3f69987bc21a5bf16f55372535c16b59e2480c07f4576ac61"}, - {file = "pytype-2020.7.24-cp35-cp35m-manylinux2014_x86_64.whl", hash = "sha256:d1bc9066c922f1179e7ba2da320094f88942cf36da91654f5e9dc2968b9ba180"}, - {file = "pytype-2020.7.24-cp36-cp36m-macosx_10_14_x86_64.whl", hash = "sha256:d2193d30730f3d6e9c1dcb284cfaa1287ec642265cf2a0e972bcb0514c3b80a9"}, - {file = "pytype-2020.7.24-cp36-cp36m-manylinux2014_x86_64.whl", hash = "sha256:17172f17daddcc67c281df11ffbc2f0847b1e71d799180ace79fc398b1b16315"}, - {file = "pytype-2020.7.24-cp37-cp37m-macosx_10_14_x86_64.whl", hash = "sha256:3ca6eb0d966d30523f8249cfc37885cfb3faa169587576592e9c96dbba568a5e"}, - {file = "pytype-2020.7.24-cp37-cp37m-manylinux2014_x86_64.whl", hash = "sha256:49d3d6635241c7924ea671a61b47f1ec35d5e64df1f884d0c29279697cf1ca2a"}, - {file = "pytype-2020.7.24.tar.gz", hash = "sha256:9ebc2b1351621323ce8644f5d20fac190ceb8ebb911aa9f880a1a93c6844803e"}, + {file = "pytype-2020.8.10-cp35-cp35m-macosx_10_14_x86_64.whl", hash = "sha256:da1977a1aa74fbd237e889c1d29421d490e0be9a91a22efd96fbca2570ef9165"}, + {file = "pytype-2020.8.10-cp35-cp35m-manylinux2014_x86_64.whl", hash = "sha256:e0909b99aff8eff0ece91fd64e00b935f0e4fecb51359d83d742b27db160dd00"}, + {file = "pytype-2020.8.10-cp36-cp36m-macosx_10_14_x86_64.whl", hash = "sha256:6b0cd56b411738eb607a299437ac405cc94208875e97ba56332105103676a903"}, + {file = "pytype-2020.8.10-cp36-cp36m-manylinux2014_x86_64.whl", hash = "sha256:08662a6d426a4ef246ba36a807d526734f437451a78f83a140d338e305bc877a"}, + {file = "pytype-2020.8.10-cp37-cp37m-macosx_10_14_x86_64.whl", hash = "sha256:768c9ea0b08f40ce8e1ed8b9207862394d770fe3340ebebfd4210a82af530d67"}, + {file = "pytype-2020.8.10-cp37-cp37m-manylinux2014_x86_64.whl", hash = "sha256:cd4399abd26a4a3498eca1ac4ad264cb407b17dfd826f9d5f14d6c06c78cf42a"}, + {file = "pytype-2020.8.10.tar.gz", hash = "sha256:6385b6837a6db69c42eb477e8f7539c0b986ec6753eab4d811553d63d58a7785"}, ] pytz = [ {file = "pytz-2020.1-py2.py3-none-any.whl", hash = "sha256:a494d53b6d39c3c6e44c3bec237336e14305e4f29bbf800b599253057fbb79ed"}, @@ -2616,34 +2706,34 @@ pyyaml = [ {file = "PyYAML-5.3.1.tar.gz", hash = "sha256:b8eac752c5e14d3eca0e6dd9199cd627518cb5ec06add0de9d32baeee6fe645d"}, ] pyzmq = [ - {file = "pyzmq-19.0.1-cp27-cp27m-macosx_10_9_intel.whl", hash = "sha256:58688a2dfa044fad608a8e70ba8d019d0b872ec2acd75b7b5e37da8905605891"}, - {file = "pyzmq-19.0.1-cp27-cp27m-win32.whl", hash = "sha256:87c78f6936e2654397ca2979c1d323ee4a889eef536cc77a938c6b5be33351a7"}, - {file = "pyzmq-19.0.1-cp27-cp27m-win_amd64.whl", hash = "sha256:97b6255ae77328d0e80593681826a0479cb7bac0ba8251b4dd882f5145a2293a"}, - {file = "pyzmq-19.0.1-cp27-cp27mu-manylinux1_i686.whl", hash = "sha256:15b4cb21118f4589c4db8be4ac12b21c8b4d0d42b3ee435d47f686c32fe2e91f"}, - {file = "pyzmq-19.0.1-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:931339ac2000d12fe212e64f98ce291e81a7ec6c73b125f17cf08415b753c087"}, - {file = "pyzmq-19.0.1-cp35-cp35m-macosx_10_9_intel.whl", hash = "sha256:2a88b8fabd9cc35bd59194a7723f3122166811ece8b74018147a4ed8489e6421"}, - {file = "pyzmq-19.0.1-cp35-cp35m-manylinux1_i686.whl", hash = 
"sha256:bafd651b557dd81d89bd5f9c678872f3e7b7255c1c751b78d520df2caac80230"}, - {file = "pyzmq-19.0.1-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:8952f6ba6ae598e792703f3134af5a01af8f5c7cf07e9a148f05a12b02412cea"}, - {file = "pyzmq-19.0.1-cp35-cp35m-win32.whl", hash = "sha256:54aa24fd60c4262286fc64ca632f9e747c7cc3a3a1144827490e1dc9b8a3a960"}, - {file = "pyzmq-19.0.1-cp35-cp35m-win_amd64.whl", hash = "sha256:dcbc3f30c11c60d709c30a213dc56e88ac016fe76ac6768e64717bd976072566"}, - {file = "pyzmq-19.0.1-cp36-cp36m-macosx_10_9_intel.whl", hash = "sha256:6ca519309703e95d55965735a667809bbb65f52beda2fdb6312385d3e7a6d234"}, - {file = "pyzmq-19.0.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:4ee0bfd82077a3ff11c985369529b12853a4064320523f8e5079b630f9551448"}, - {file = "pyzmq-19.0.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:ba6f24431b569aec674ede49cad197cad59571c12deed6ad8e3c596da8288217"}, - {file = "pyzmq-19.0.1-cp36-cp36m-win32.whl", hash = "sha256:956775444d01331c7eb412c5fb9bb62130dfaac77e09f32764ea1865234e2ca9"}, - {file = "pyzmq-19.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:b08780e3a55215873b3b8e6e7ca8987f14c902a24b6ac081b344fd430d6ca7cd"}, - {file = "pyzmq-19.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:21f7d91f3536f480cb2c10d0756bfa717927090b7fb863e6323f766e5461ee1c"}, - {file = "pyzmq-19.0.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:bfff5ffff051f5aa47ba3b379d87bd051c3196b0c8a603e8b7ed68a6b4f217ec"}, - {file = "pyzmq-19.0.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:07fb8fe6826a229dada876956590135871de60dbc7de5a18c3bcce2ed1f03c98"}, - {file = "pyzmq-19.0.1-cp37-cp37m-win32.whl", hash = "sha256:342fb8a1dddc569bc361387782e8088071593e7eaf3e3ecf7d6bd4976edff112"}, - {file = "pyzmq-19.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:faee2604f279d31312bc455f3d024f160b6168b9c1dde22bf62d8c88a4deca8e"}, - {file = "pyzmq-19.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5b9d21fc56c8aacd2e6d14738021a9d64f3f69b30578a99325a728e38a349f85"}, - {file = "pyzmq-19.0.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:af0c02cf49f4f9eedf38edb4f3b6bb621d83026e7e5d76eb5526cc5333782fd6"}, - {file = "pyzmq-19.0.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:5f1f2eb22aab606f808163eb1d537ac9a0ba4283fbeb7a62eb48d9103cf015c2"}, - {file = "pyzmq-19.0.1-cp38-cp38-win32.whl", hash = "sha256:f9d7e742fb0196992477415bb34366c12e9bb9a0699b8b3f221ff93b213d7bec"}, - {file = "pyzmq-19.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:5b99c2ae8089ef50223c28bac57510c163bfdff158c9e90764f812b94e69a0e6"}, - {file = "pyzmq-19.0.1-pp27-pypy_73-macosx_10_9_x86_64.whl", hash = "sha256:cf5d689ba9513b9753959164cf500079383bc18859f58bf8ce06d8d4bef2b054"}, - {file = "pyzmq-19.0.1-pp36-pypy36_pp73-macosx_10_9_x86_64.whl", hash = "sha256:aaa8b40b676576fd7806839a5de8e6d5d1b74981e6376d862af6c117af2a3c10"}, - {file = "pyzmq-19.0.1.tar.gz", hash = "sha256:13a5638ab24d628a6ade8f794195e1a1acd573496c3b85af2f1183603b7bf5e0"}, + {file = "pyzmq-19.0.2-cp27-cp27m-macosx_10_9_intel.whl", hash = "sha256:59f1e54627483dcf61c663941d94c4af9bf4163aec334171686cdaee67974fe5"}, + {file = "pyzmq-19.0.2-cp27-cp27m-win32.whl", hash = "sha256:c36ffe1e5aa35a1af6a96640d723d0d211c5f48841735c2aa8d034204e87eb87"}, + {file = "pyzmq-19.0.2-cp27-cp27m-win_amd64.whl", hash = "sha256:0a422fc290d03958899743db091f8154958410fc76ce7ee0ceb66150f72c2c97"}, + {file = "pyzmq-19.0.2-cp27-cp27mu-manylinux1_i686.whl", hash = "sha256:c20dd60b9428f532bc59f2ef6d3b1029a28fc790d408af82f871a7db03e722ff"}, + {file = 
"pyzmq-19.0.2-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:d46fb17f5693244de83e434648b3dbb4f4b0fec88415d6cbab1c1452b6f2ae17"}, + {file = "pyzmq-19.0.2-cp35-cp35m-macosx_10_9_intel.whl", hash = "sha256:f1a25a61495b6f7bb986accc5b597a3541d9bd3ef0016f50be16dbb32025b302"}, + {file = "pyzmq-19.0.2-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:ab0d01148d13854de716786ca73701012e07dff4dfbbd68c4e06d8888743526e"}, + {file = "pyzmq-19.0.2-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:720d2b6083498a9281eaee3f2927486e9fe02cd16d13a844f2e95217f243efea"}, + {file = "pyzmq-19.0.2-cp35-cp35m-win32.whl", hash = "sha256:29d51279060d0a70f551663bc592418bcad7f4be4eea7b324f6dd81de05cb4c1"}, + {file = "pyzmq-19.0.2-cp35-cp35m-win_amd64.whl", hash = "sha256:5120c64646e75f6db20cc16b9a94203926ead5d633de9feba4f137004241221d"}, + {file = "pyzmq-19.0.2-cp36-cp36m-macosx_10_9_intel.whl", hash = "sha256:8a6ada5a3f719bf46a04ba38595073df8d6b067316c011180102ba2a1925f5b5"}, + {file = "pyzmq-19.0.2-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:fa411b1d8f371d3a49d31b0789eb6da2537dadbb2aef74a43aa99a78195c3f76"}, + {file = "pyzmq-19.0.2-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:00dca814469436455399660247d74045172955459c0bd49b54a540ce4d652185"}, + {file = "pyzmq-19.0.2-cp36-cp36m-win32.whl", hash = "sha256:046b92e860914e39612e84fa760fc3f16054d268c11e0e25dcb011fb1bc6a075"}, + {file = "pyzmq-19.0.2-cp36-cp36m-win_amd64.whl", hash = "sha256:99cc0e339a731c6a34109e5c4072aaa06d8e32c0b93dc2c2d90345dd45fa196c"}, + {file = "pyzmq-19.0.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e36f12f503511d72d9bdfae11cadbadca22ff632ff67c1b5459f69756a029c19"}, + {file = "pyzmq-19.0.2-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:c40fbb2b9933369e994b837ee72193d6a4c35dfb9a7c573257ef7ff28961272c"}, + {file = "pyzmq-19.0.2-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:5d9fc809aa8d636e757e4ced2302569d6e60e9b9c26114a83f0d9d6519c40493"}, + {file = "pyzmq-19.0.2-cp37-cp37m-win32.whl", hash = "sha256:3fa6debf4bf9412e59353defad1f8035a1e68b66095a94ead8f7a61ae90b2675"}, + {file = "pyzmq-19.0.2-cp37-cp37m-win_amd64.whl", hash = "sha256:73483a2caaa0264ac717af33d6fb3f143d8379e60a422730ee8d010526ce1913"}, + {file = "pyzmq-19.0.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:36ab114021c0cab1a423fe6689355e8f813979f2c750968833b318c1fa10a0fd"}, + {file = "pyzmq-19.0.2-cp38-cp38-manylinux1_i686.whl", hash = "sha256:8b66b94fe6243d2d1d89bca336b2424399aac57932858b9a30309803ffc28112"}, + {file = "pyzmq-19.0.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:654d3e06a4edc566b416c10293064732516cf8871a4522e0a2ba00cc2a2e600c"}, + {file = "pyzmq-19.0.2-cp38-cp38-win32.whl", hash = "sha256:276ad604bffd70992a386a84bea34883e696a6b22e7378053e5d3227321d9702"}, + {file = "pyzmq-19.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:09d24a80ccb8cbda1af6ed8eb26b005b6743e58e9290566d2a6841f4e31fa8e0"}, + {file = "pyzmq-19.0.2-pp27-pypy_73-macosx_10_9_x86_64.whl", hash = "sha256:c1a31cd42905b405530e92bdb70a8a56f048c8a371728b8acf9d746ecd4482c0"}, + {file = "pyzmq-19.0.2-pp36-pypy36_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a7e7f930039ee0c4c26e4dfee015f20bd6919cd8b97c9cd7afbde2923a5167b6"}, + {file = "pyzmq-19.0.2.tar.gz", hash = "sha256:296540a065c8c21b26d63e3cea2d1d57902373b16e4256afe46422691903a438"}, ] qtconsole = [ {file = "qtconsole-4.7.5-py2.py3-none-any.whl", hash = "sha256:4f43d0b049eacb7d723772847f0c465feccce0ccb398871a6e146001a22bad23"}, @@ -2696,8 +2786,8 @@ send2trash = [ {file = "Send2Trash-1.5.0.tar.gz", hash = 
"sha256:60001cc07d707fe247c94f74ca6ac0d3255aabcb930529690897ca2a39db28b2"}, ] sentry-sdk = [ - {file = "sentry-sdk-0.16.2.tar.gz", hash = "sha256:38bb09d0277117f76507c8728d9a5156f09a47ac5175bb8072513859d19a593b"}, - {file = "sentry_sdk-0.16.2-py2.py3-none-any.whl", hash = "sha256:2de15b13836fa3522815a933bd9c887c77f4868071043349f94f1b896c1bcfb8"}, + {file = "sentry-sdk-0.16.5.tar.gz", hash = "sha256:e12eb1c2c01cd9e9cfe70608dbda4ef451f37ef0b7cbb92e5d43f87c341d6334"}, + {file = "sentry_sdk-0.16.5-py2.py3-none-any.whl", hash = "sha256:d359609e23ec9360b61e5ffdfa417e2f6bca281bfb869608c98c169c7e64acd5"}, ] shortuuid = [ {file = "shortuuid-1.0.1-py3-none-any.whl", hash = "sha256:492c7402ff91beb1342a5898bd61ea953985bf24a41cd9f247409aa2e03c8f77"}, @@ -2716,8 +2806,8 @@ snowballstemmer = [ {file = "snowballstemmer-2.0.0.tar.gz", hash = "sha256:df3bac3df4c2c01363f3dd2cfa78cce2840a79b9f1c2d2de9ce8d31683992f52"}, ] sphinx = [ - {file = "Sphinx-3.1.2-py3-none-any.whl", hash = "sha256:97dbf2e31fc5684bb805104b8ad34434ed70e6c588f6896991b2fdfd2bef8c00"}, - {file = "Sphinx-3.1.2.tar.gz", hash = "sha256:b9daeb9b39aa1ffefc2809b43604109825300300b987a24f45976c001ba1a8fd"}, + {file = "Sphinx-3.2.1-py3-none-any.whl", hash = "sha256:ce6fd7ff5b215af39e2fcd44d4a321f6694b4530b6f2b2109b64d120773faea0"}, + {file = "Sphinx-3.2.1.tar.gz", hash = "sha256:321d6d9b16fa381a5306e5a0b76cd48ffbc588e6340059a729c6fdd66087e0e8"}, ] sphinx-autodoc-typehints = [ {file = "sphinx-autodoc-typehints-1.11.0.tar.gz", hash = "sha256:bbf0b203f1019b0f9843ee8eef0cff856dc04b341f6dbe1113e37f2ebf243e11"}, @@ -2772,28 +2862,24 @@ toml = [ {file = "toml-0.10.1.tar.gz", hash = "sha256:926b612be1e5ce0634a2ca03470f95169cf16f939018233a670519cb4ac58b0f"}, ] torch = [ - {file = "torch-1.5.1-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:b84fd18fd8216b74a19828433c3beeb1f0d1d29f45dead3be9ed784ae6855966"}, - {file = "torch-1.5.1-cp35-none-macosx_10_6_x86_64.whl", hash = "sha256:5d909a55cd979fec2c9a7aa35012024b9cc106acbc496faf5de798b148406450"}, - {file = "torch-1.5.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:a358cee1d35b86757bf915e320ba776d39c20e60db50779060842efc86f02edd"}, - {file = "torch-1.5.1-cp36-none-macosx_10_9_x86_64.whl", hash = "sha256:0a83f41140222c7cc947aa29ed253f3e6fa490606d3d4acd02bfd9f338e3c707"}, - {file = "torch-1.5.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:70046cf66eb40ead89df25b8dcc571c3007fc9849d4e1d254cc09b4b355374d4"}, - {file = "torch-1.5.1-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:bb2a3e6c9c9dbfda856bd1b1a55d88789a9488b569ffba9cd6d9aa536ef866ba"}, - {file = "torch-1.5.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:c42658f2982591dc4d0459645c9ab26e0ce18aa7ab0993c27c8bcb1c98931d11"}, - {file = "torch-1.5.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:ff1dbeaa017bae66036e8e7a698a5475ac5a0d7b0a690f0a04ac3b1133b1feb3"}, + {file = "torch-1.6.0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:7669f4d923b5758e28b521ea749c795ed67ff24b45ba20296bc8cff706d08df8"}, + {file = "torch-1.6.0-cp36-none-macosx_10_9_x86_64.whl", hash = "sha256:728facb972a5952323c6d790c2c5922b2b35c44b0bc7bdfa02f8639727671a0c"}, + {file = "torch-1.6.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:87d65c01d1b70bb46070824f28bfd93c86d3c5c56b90cbbe836a3f2491d91c76"}, + {file = "torch-1.6.0-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:3838bd01af7dfb1f78573973f6842ce75b17e8e4f22be99c891dcb7c94bc13f5"}, + {file = "torch-1.6.0-cp38-cp38-manylinux1_x86_64.whl", hash = 
"sha256:5357873e243bcfa804c32dc341f564e9a4c12addfc9baae4ee857fcc09a0a216"}, + {file = "torch-1.6.0-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:4f9a4ad7947cef566afb0a323d99009fe8524f0b0f2ca1fb7ad5de0400381a5b"}, ] torchsummary = [ {file = "torchsummary-1.5.1-py3-none-any.whl", hash = "sha256:10f41d1743fb918f83293f13183f532ab1bb8f6639a1b89e5f8592ec1919a976"}, {file = "torchsummary-1.5.1.tar.gz", hash = "sha256:981bf689e22e0cf7f95c746002f20a24ad26aa6b9d861134a14bc6ce92230590"}, ] torchvision = [ - {file = "torchvision-0.6.1-cp35-cp35m-macosx_10_6_x86_64.whl", hash = "sha256:eb6d7ef73ab8ed756d18c1c11ead3ba1be7b3e2fe5bf475e16a4426d7f3d6eec"}, - {file = "torchvision-0.6.1-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:c122386a4604a6b611626d817981777bb06e747161a4954e4037c2e297fe1d86"}, - {file = "torchvision-0.6.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:9a544f241e1df7485c1fda64f06af4962ad4af1da6caeda352757ea800dd794d"}, - {file = "torchvision-0.6.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:c44819c3c903f2d6d269713a9aa81a9dcb0ab7716b4fc4dbdccf596e6fe894c9"}, - {file = "torchvision-0.6.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b07e7c5a7274082a8bbe88975c817da10bffaf4020017347d185b5d4eacb891d"}, - {file = "torchvision-0.6.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:52be3710bf495fee9eb8f1cb4d310edd2da94b127b2ca0aa77075c1001bcbb91"}, - {file = "torchvision-0.6.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:80a8beaa3416035c9640d153b1f0a7c9b09aac8037a032bbc33fd4bc8657b091"}, - {file = "torchvision-0.6.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:3af59c61d20bb0bb63dbd50f324fcd975fc5b89fb78e6b5714ae613cd7b734cc"}, + {file = "torchvision-0.7.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:a70d80bb8749c1e4a46fa56dc2fc857e98d14600841e02cc2fed766daf96c245"}, + {file = "torchvision-0.7.0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:14c0bf60fa26aabaea64ef30b8e5d441ee78d1a5eed568c30806af19bbe6b638"}, + {file = "torchvision-0.7.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8c8df7e1d1f3d4e088256be1e8c8d3eb90b016302baa4649742d47ae1531da37"}, + {file = "torchvision-0.7.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:0d1a5adfef4387659c7a0af3b72e16caa0c67224a422050ab65184d13ac9fb13"}, + {file = "torchvision-0.7.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f5686e0a0dd511ac33eb9d6279bd34edd9f282dcb7c8ad21e290882c6206504f"}, + {file = "torchvision-0.7.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:cfa2b367bc9acf20f18b151d0525970279719e81969c17214effe77245875354"}, ] tornado = [ {file = "tornado-6.0.4-cp35-cp35m-win32.whl", hash = "sha256:5217e601700f24e966ddab689f90b7ea4bd91ff3357c3600fa1045e26d68e55d"}, @@ -2807,8 +2893,8 @@ tornado = [ {file = "tornado-6.0.4.tar.gz", hash = "sha256:0fe2d45ba43b00a41cd73f8be321a44936dc1aba233dee979f17a042b83eb6dc"}, ] tqdm = [ - {file = "tqdm-4.48.0-py2.py3-none-any.whl", hash = "sha256:fcb7cb5b729b60a27f300b15c1ffd4744f080fb483b88f31dc8654b082cc8ea5"}, - {file = "tqdm-4.48.0.tar.gz", hash = "sha256:6baa75a88582b1db6d34ce4690da5501d2a1cb65c34664840a456b2c9f794d29"}, + {file = "tqdm-4.48.2-py2.py3-none-any.whl", hash = "sha256:1a336d2b829be50e46b84668691e0a2719f26c97c62846298dd5ae2937e4d5cf"}, + {file = "tqdm-4.48.2.tar.gz", hash = "sha256:564d632ea2b9cb52979f7956e093e831c28d441c11751682f84c86fc46e4fd21"}, ] traitlets = [ {file = "traitlets-4.3.3-py2.py3-none-any.whl", hash = "sha256:70b4c6a1d9019d7b4f6846832288f86998aa3b9207c6821f3578a6a6a467fe44"}, @@ -2856,8 +2942,8 @@ urllib3 = [ {file = 
"urllib3-1.25.10.tar.gz", hash = "sha256:91056c15fa70756691db97756772bb1eb9678fa585d9184f24534b100dc60f4a"}, ] wandb = [ - {file = "wandb-0.9.4-py2.py3-none-any.whl", hash = "sha256:6c2b2fff356803f38a2620737f7ba280c5bc770146c711c52a2f4d233b1a0579"}, - {file = "wandb-0.9.4.tar.gz", hash = "sha256:b3afba328455e885fd0af74fb8e40eb9fe1290d18d1e24181ab6f551f888532d"}, + {file = "wandb-0.9.5-py2.py3-none-any.whl", hash = "sha256:f8aa54c75736aafdea600b2d3b803d4d3f999d1ac2568c253552dd557b767d9a"}, + {file = "wandb-0.9.5.tar.gz", hash = "sha256:251ac7bcf9acc1b4ed08b1f3ea9172ed02d1e55a1a9033785e544f009208958e"}, ] watchdog = [ {file = "watchdog-0.10.3.tar.gz", hash = "sha256:4214e1379d128b0588021880ccaf40317ee156d4603ac388b9adcf29165e0c04"}, diff --git a/pyproject.toml b/pyproject.toml index 15d1f57..49bb049 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,8 +22,8 @@ sphinx_rtd_theme = "^0.4.3" boltons = "^20.1.0" h5py = "^2.10.0" toml = "^0.10.1" -torch = "^1.5.0" -torchvision = "^0.6.0" +torch = "^1.6.0" +torchvision = "^0.7.0" torchsummary = "^1.5.1" loguru = "^0.5.0" matplotlib = "^3.2.1" @@ -32,6 +32,7 @@ pytest = "^5.4.3" opencv-python = "^4.3.0" nltk = "^3.5" einops = "^0.2.0" +wandb = "^0.9.5" [tool.poetry.dev-dependencies] pytest = "^5.4.2" diff --git a/src/notebooks/00-testing-stuff-out.ipynb b/src/notebooks/00-testing-stuff-out.ipynb index 49ca4c4..3f008c3 100644 --- a/src/notebooks/00-testing-stuff-out.ipynb +++ b/src/notebooks/00-testing-stuff-out.ipynb @@ -1,12 +1,121 @@ { "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "%matplotlib inline\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from PIL import Image\n", + "import torch\n", + "from importlib.util import find_spec\n", + "if find_spec(\"text_recognizer\") is None:\n", + " import sys\n", + " sys.path.append('..')" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.cuda.is_available()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.nn.modules.activation.SELU" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.nn.SELU" + ] + }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ - "import torch" + "a = \"nNone\"" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "b = a or \"relu\"" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'nnone'" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "b.lower()" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'nNone'" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "b" ] }, { @@ -986,28 +1095,16 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 51, "metadata": {}, - "outputs": [ - { - "ename": "ModuleNotFoundError", - "evalue": "No module named 'tqdm.auto.tqdm'; 'tqdm.auto' is not a 
package", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtqdm\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mtqdm_auto\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'tqdm.auto.tqdm'; 'tqdm.auto' is not a package" - ] - } - ], + "outputs": [], "source": [ - "import tqdm.auto.tqdm as tqdm_auto" + "import tqdm" ] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 52, "metadata": {}, "outputs": [ { @@ -1016,7 +1113,7 @@ "tqdm.notebook.tqdm_notebook" ] }, - "execution_count": 19, + "execution_count": 52, "metadata": {}, "output_type": "execute_result" } @@ -1027,25 +1124,50 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tqdm.auto.tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": 54, "metadata": {}, "outputs": [], "source": [ "def test():\n", - " for i in range(9):\n", + " for i in tqdm.auto.tqdm(range(9)):\n", " pass\n", - " print(i)" + " print(i)\n", + " " ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 55, "metadata": {}, "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e1d3b25d4ee141e882e316ec54e79d60", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "name": "stdout", "output_type": "stream", "text": [ + "\n", "8\n" ] } @@ -1054,6 +1176,479 @@ "test()" ] }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [], + "source": [ + "from time import sleep" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "41b743273ce14236bcb65782dbcd2e75", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "pbar = tqdm.auto.tqdm([\"a\", \"b\", \"c\", \"d\"], leave=True)\n", + "for char in pbar:\n", + " pbar.set_description(\"Processing %s\" % char)\n", + "# pbar.set_prefix()\n", + " sleep(0.25)\n", + "pbar.set_postfix({\"hej\": 0.32})" + ] + }, + { + "cell_type": "code", + "execution_count": 83, + "metadata": {}, + "outputs": [], + "source": [ + "pbar.close()" + ] + }, + { + "cell_type": "code", + "execution_count": 96, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "cb5ad8d6109f4b1495b8fc7422bafd01", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "with tqdm.auto.tqdm(total=10, 
bar_format=\"{postfix[0]} {postfix[1][value]:>8.2g}\",\n", + " postfix=[\"Batch\", dict(value=0)]) as t:\n", + " for i in range(10):\n", + " sleep(0.1)\n", + "# t.postfix[2][\"value\"] = 3 \n", + " t.postfix[1][\"value\"] = i / 2\n", + " t.update()" + ] + }, + { + "cell_type": "code", + "execution_count": 99, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0b341d49ad074823881e84a538bcad0c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "with tqdm.auto.tqdm(total=100, leave=True) as pbar:\n", + " for i in range(2):\n", + " for i in range(10):\n", + " sleep(0.1)\n", + " pbar.update(10)\n", + " pbar.set_postfix({\"adaf\": 23})\n", + " pbar.set_postfix({\"hej\": 0.32})\n", + " pbar.reset()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "from text_recognizer.networks.residual_network import IdentityBlock, ResidualBlock, BasicBlock, BottleNeckBlock, ResidualLayer, Encoder, ResidualNetwork" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "IdentityBlock(\n", + " (blocks): Identity()\n", + " (activation_fn): ReLU(inplace=True)\n", + " (shortcut): Identity()\n", + ")" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "IdentityBlock(32, 64)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ResidualBlock(\n", + " (blocks): Identity()\n", + " (activation_fn): ReLU(inplace=True)\n", + " (shortcut): Sequential(\n", + " (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + ")" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ResidualBlock(32, 64)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "BasicBlock(\n", + " (blocks): Sequential(\n", + " (0): Sequential(\n", + " (0): Conv2dAuto(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (1): ReLU(inplace=True)\n", + " (2): Sequential(\n", + " (0): Conv2dAuto(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (activation_fn): ReLU(inplace=True)\n", + " (shortcut): Sequential(\n", + " (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + ")\n" + ] + } + ], + "source": [ + "dummy = torch.ones((1, 32, 224, 224))\n", + "\n", + "block = BasicBlock(32, 64)\n", + "block(dummy).shape\n", + "print(block)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "BottleNeckBlock(\n", + " (blocks): 
Sequential(\n", + " (0): Sequential(\n", + " (0): Conv2dAuto(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (1): ReLU(inplace=True)\n", + " (2): Sequential(\n", + " (0): Conv2dAuto(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (3): ReLU(inplace=True)\n", + " (4): Sequential(\n", + " (0): Conv2dAuto(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (activation_fn): ReLU(inplace=True)\n", + " (shortcut): Sequential(\n", + " (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + ")\n" + ] + } + ], + "source": [ + "dummy = torch.ones((1, 32, 10, 10))\n", + "\n", + "block = BottleNeckBlock(32, 64)\n", + "block(dummy).shape\n", + "print(block)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 128, 24, 24])" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dummy = torch.ones((1, 64, 48, 48))\n", + "\n", + "layer = ResidualLayer(64, 128, block=BasicBlock, num_blocks=3)\n", + "layer(dummy).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[(64, 128), (128, 256), (256, 512)]" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "blocks_sizes=[64, 128, 256, 512]\n", + "list(zip(blocks_sizes, blocks_sizes[1:]))" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "e = Encoder(depths=[1, 1])" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "from torchsummary import summary" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------------------------------------\n", + " Layer (type) Output Shape Param #\n", + "================================================================\n", + " Conv2d-1 [-1, 32, 15, 15] 800\n", + " BatchNorm2d-2 [-1, 32, 15, 15] 64\n", + " ReLU-3 [-1, 32, 15, 15] 0\n", + " MaxPool2d-4 [-1, 32, 8, 8] 0\n", + " Conv2dAuto-5 [-1, 32, 8, 8] 9,216\n", + " BatchNorm2d-6 [-1, 32, 8, 8] 64\n", + " ReLU-7 [-1, 32, 8, 8] 0\n", + " ReLU-8 [-1, 32, 8, 8] 0\n", + " Conv2dAuto-9 [-1, 32, 8, 8] 9,216\n", + " BatchNorm2d-10 [-1, 32, 8, 8] 64\n", + " ReLU-11 [-1, 32, 8, 8] 0\n", + " ReLU-12 [-1, 32, 8, 8] 0\n", + " BasicBlock-13 [-1, 32, 8, 8] 0\n", + " ResidualLayer-14 [-1, 32, 8, 8] 0\n", + " Conv2d-15 [-1, 64, 4, 4] 2,048\n", + " BatchNorm2d-16 [-1, 64, 4, 4] 128\n", + " Conv2dAuto-17 [-1, 64, 4, 4] 18,432\n", + " BatchNorm2d-18 [-1, 64, 4, 4] 128\n", + " ReLU-19 [-1, 64, 4, 4] 0\n", + " ReLU-20 [-1, 64, 4, 4] 0\n", + " Conv2dAuto-21 [-1, 64, 4, 4] 36,864\n", + " BatchNorm2d-22 [-1, 64, 4, 4] 128\n", + " ReLU-23 [-1, 64, 4, 4] 0\n", + " ReLU-24 [-1, 64, 4, 4] 0\n", + " BasicBlock-25 [-1, 64, 4, 4] 0\n", + " ResidualLayer-26 [-1, 64, 4, 4] 0\n", + 
"================================================================\n", + "Total params: 77,152\n", + "Trainable params: 77,152\n", + "Non-trainable params: 0\n", + "----------------------------------------------------------------\n", + "Input size (MB): 0.00\n", + "Forward/backward pass size (MB): 0.43\n", + "Params size (MB): 0.29\n", + "Estimated Total Size (MB): 0.73\n", + "----------------------------------------------------------------\n" + ] + } + ], + "source": [ + "summary(e, (1, 28, 28), device=\"cpu\")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "resnet = ResidualNetwork(1, 80, activation=\"selu\")" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------------------------------------\n", + " Layer (type) Output Shape Param #\n", + "================================================================\n", + " Conv2d-1 [-1, 32, 15, 15] 800\n", + " BatchNorm2d-2 [-1, 32, 15, 15] 64\n", + " SELU-3 [-1, 32, 15, 15] 0\n", + " MaxPool2d-4 [-1, 32, 8, 8] 0\n", + " Conv2dAuto-5 [-1, 32, 8, 8] 9,216\n", + " BatchNorm2d-6 [-1, 32, 8, 8] 64\n", + " SELU-7 [-1, 32, 8, 8] 0\n", + " SELU-8 [-1, 32, 8, 8] 0\n", + " Conv2dAuto-9 [-1, 32, 8, 8] 9,216\n", + " BatchNorm2d-10 [-1, 32, 8, 8] 64\n", + " SELU-11 [-1, 32, 8, 8] 0\n", + " SELU-12 [-1, 32, 8, 8] 0\n", + " BasicBlock-13 [-1, 32, 8, 8] 0\n", + " Conv2dAuto-14 [-1, 32, 8, 8] 9,216\n", + " BatchNorm2d-15 [-1, 32, 8, 8] 64\n", + " SELU-16 [-1, 32, 8, 8] 0\n", + " SELU-17 [-1, 32, 8, 8] 0\n", + " Conv2dAuto-18 [-1, 32, 8, 8] 9,216\n", + " BatchNorm2d-19 [-1, 32, 8, 8] 64\n", + " SELU-20 [-1, 32, 8, 8] 0\n", + " SELU-21 [-1, 32, 8, 8] 0\n", + " BasicBlock-22 [-1, 32, 8, 8] 0\n", + " ResidualLayer-23 [-1, 32, 8, 8] 0\n", + " Conv2d-24 [-1, 64, 4, 4] 2,048\n", + " BatchNorm2d-25 [-1, 64, 4, 4] 128\n", + " Conv2dAuto-26 [-1, 64, 4, 4] 18,432\n", + " BatchNorm2d-27 [-1, 64, 4, 4] 128\n", + " SELU-28 [-1, 64, 4, 4] 0\n", + " SELU-29 [-1, 64, 4, 4] 0\n", + " Conv2dAuto-30 [-1, 64, 4, 4] 36,864\n", + " BatchNorm2d-31 [-1, 64, 4, 4] 128\n", + " SELU-32 [-1, 64, 4, 4] 0\n", + " SELU-33 [-1, 64, 4, 4] 0\n", + " BasicBlock-34 [-1, 64, 4, 4] 0\n", + " Conv2dAuto-35 [-1, 64, 4, 4] 36,864\n", + " BatchNorm2d-36 [-1, 64, 4, 4] 128\n", + " SELU-37 [-1, 64, 4, 4] 0\n", + " SELU-38 [-1, 64, 4, 4] 0\n", + " Conv2dAuto-39 [-1, 64, 4, 4] 36,864\n", + " BatchNorm2d-40 [-1, 64, 4, 4] 128\n", + " SELU-41 [-1, 64, 4, 4] 0\n", + " SELU-42 [-1, 64, 4, 4] 0\n", + " BasicBlock-43 [-1, 64, 4, 4] 0\n", + " ResidualLayer-44 [-1, 64, 4, 4] 0\n", + " Encoder-45 [-1, 64, 4, 4] 0\n", + " Reduce-46 [-1, 64] 0\n", + " Linear-47 [-1, 80] 5,200\n", + " Decoder-48 [-1, 80] 0\n", + "================================================================\n", + "Total params: 174,896\n", + "Trainable params: 174,896\n", + "Non-trainable params: 0\n", + "----------------------------------------------------------------\n", + "Input size (MB): 0.00\n", + "Forward/backward pass size (MB): 0.65\n", + "Params size (MB): 0.67\n", + "Estimated Total Size (MB): 1.32\n", + "----------------------------------------------------------------\n" + ] + } + ], + "source": [ + "summary(resnet, (1, 28, 28), device=\"cpu\")" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/src/notebooks/01-look-at-emnist.ipynb b/src/notebooks/01-look-at-emnist.ipynb index a68b418..8648afb 100644 --- 
a/src/notebooks/01-look-at-emnist.ipynb +++ b/src/notebooks/01-look-at-emnist.ipynb @@ -31,12 +31,31 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ - "dataset = EmnistDataset()\n", - "dataset.load_emnist_dataset()" + "dataset = EmnistDataset()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Tensor" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "type(dataset.data)" ] }, { diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py index 96f84e5..49ebad3 100644 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ b/src/text_recognizer/datasets/emnist_dataset.py @@ -8,6 +8,7 @@ from loguru import logger import numpy as np from PIL import Image import torch +from torch import Tensor from torch.utils.data import DataLoader, Dataset from torchvision.datasets import EMNIST from torchvision.transforms import Compose, Normalize, ToTensor @@ -183,12 +184,8 @@ class EmnistDataset(Dataset): self.input_shape = self._mapper.input_shape self.num_classes = self._mapper.num_classes - # Placeholders - self.data = None - self.targets = None - # Load dataset. - self.load_emnist_dataset() + self.data, self.targets = self.load_emnist_dataset() @property def mapper(self) -> EmnistMapper: @@ -199,9 +196,7 @@ class EmnistDataset(Dataset): """Returns the length of the dataset.""" return len(self.data) - def __getitem__( - self, index: Union[int, torch.Tensor] - ) -> Tuple[torch.Tensor, torch.Tensor]: + def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]: """Fetches samples from the dataset. 
Args: @@ -239,11 +234,13 @@ class EmnistDataset(Dataset): f"Mapping: {self.mapper.mapping}\n" ) - def _sample_to_balance(self) -> None: + def _sample_to_balance( + self, data: Tensor, targets: Tensor + ) -> Tuple[np.ndarray, np.ndarray]: """Because the dataset is not balanced, we take at most the mean number of instances per class.""" np.random.seed(self.seed) - x = self.data - y = self.targets + x = data + y = targets num_to_sample = int(np.bincount(y.flatten()).mean()) all_sampled_indices = [] for label in np.unique(y.flatten()): @@ -253,20 +250,22 @@ class EmnistDataset(Dataset): indices = np.concatenate(all_sampled_indices) x_sampled = x[indices] y_sampled = y[indices] - self.data = x_sampled - self.targets = y_sampled + data = x_sampled + targets = y_sampled + return data, targets - def _subsample(self) -> None: + def _subsample(self, data: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]: """Subsamples the dataset to the specified fraction.""" - x = self.data - y = self.targets + x = data + y = targets num_samples = int(x.shape[0] * self.subsample_fraction) x_sampled = x[:num_samples] y_sampled = y[:num_samples] self.data = x_sampled self.targets = y_sampled + return data, targets - def load_emnist_dataset(self) -> None: + def load_emnist_dataset(self) -> Tuple[Tensor, Tensor]: """Fetch the EMNIST dataset.""" dataset = EMNIST( root=DATA_DIRNAME, @@ -277,11 +276,13 @@ class EmnistDataset(Dataset): target_transform=None, ) - self.data = dataset.data - self.targets = dataset.targets + data = dataset.data + targets = dataset.targets if self.sample_to_balance: - self._sample_to_balance() + data, targets = self._sample_to_balance(data, targets) if self.subsample_fraction is not None: - self._subsample() + data, targets = self._subsample(data, targets) + + return data, targets diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py index d64a991..b0617f5 100644 --- a/src/text_recognizer/datasets/emnist_lines_dataset.py +++ b/src/text_recognizer/datasets/emnist_lines_dataset.py @@ -8,6 +8,7 @@ import h5py from loguru import logger import numpy as np import torch +from torch import Tensor from torch.utils.data import DataLoader, Dataset from torchvision.transforms import Compose, Normalize, ToTensor @@ -87,16 +88,14 @@ class EmnistLinesDataset(Dataset): """Returns the length of the dataset.""" return len(self.data) - def __getitem__( - self, index: Union[int, torch.Tensor] - ) -> Tuple[torch.Tensor, torch.Tensor]: + def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]: """Fetches data, target pair of the dataset for a given and index or indices. Args: - index (Union[int, torch.Tensor]): Either a list or int of indices/index. + index (Union[int, Tensor]): Either a list or int of indices/index. Returns: - Tuple[torch.Tensor, torch.Tensor]: Data target pair. + Tuple[Tensor, Tensor]: Data target pair. """ if torch.is_tensor(index): diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index 6d40b49..74fd223 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -53,8 +53,8 @@ class Model(ABC): """ - # Fetch data loaders and dataset info. - dataset_name, self._data_loaders, self._mapper = self._load_data_loader( + # Configure data loaders and dataset info. 
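The refactored EmnistDataset above now returns (data, targets) from load_emnist_dataset and from the balancing/subsampling helpers instead of mutating instance state. The per-label sampling itself falls outside the visible hunks, so the following is only a hypothetical, standalone sketch of the class-balancing idea (cap every class at the mean class count); the np.random.choice call and the seed handling are assumptions, not the project's code.

import numpy as np
import torch

def sample_to_balance(data: torch.Tensor, targets: torch.Tensor, seed: int = 4711):
    # Cap each class at the mean number of samples per class.
    np.random.seed(seed)
    y = targets.numpy().flatten()
    num_to_sample = int(np.bincount(y).mean())
    sampled_indices = [
        np.random.choice(np.where(y == label)[0], num_to_sample)
        for label in np.unique(y)
    ]
    indices = np.concatenate(sampled_indices)
    return data[indices], targets[indices]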
+ dataset_name, self._data_loaders, self._mapper = self._configure_data_loader( data_loader_args ) self._input_shape = self._mapper.input_shape @@ -70,16 +70,19 @@ class Model(ABC): else: self._device = device - # Load network. - self._network, self._network_args = self._load_network(network_fn, network_args) + # Configure network. + self._network, self._network_args = self._configure_network( + network_fn, network_args + ) # To device. self._network.to(self._device) - # Set training objects. - self._criterion = self._load_criterion(criterion, criterion_args) - self._optimizer = self._load_optimizer(optimizer, optimizer_args) - self._lr_scheduler = self._load_lr_scheduler(lr_scheduler, lr_scheduler_args) + # Configure training objects. + self._criterion = self._configure_criterion(criterion, criterion_args) + self._optimizer, self._lr_scheduler = self._configure_optimizers( + optimizer, optimizer_args, lr_scheduler, lr_scheduler_args + ) # Experiment directory. self.model_dir = None @@ -87,7 +90,7 @@ class Model(ABC): # Flag for stopping training. self.stop_training = False - def _load_data_loader( + def _configure_data_loader( self, data_loader_args: Optional[Dict] ) -> Tuple[str, Dict, EmnistMapper]: """Loads data loader, dataset name, and dataset mapper.""" @@ -102,7 +105,7 @@ class Model(ABC): data_loaders = None return dataset_name, data_loaders, mapper - def _load_network( + def _configure_network( self, network_fn: Type[nn.Module], network_args: Optional[Dict] ) -> Tuple[Type[nn.Module], Dict]: """Loads the network.""" @@ -113,7 +116,7 @@ class Model(ABC): network = network_fn(**network_args) return network, network_args - def _load_criterion( + def _configure_criterion( self, criterion: Optional[Callable], criterion_args: Optional[Dict] ) -> Optional[Callable]: """Loads the criterion.""" @@ -123,27 +126,27 @@ class Model(ABC): _criterion = None return _criterion - def _load_optimizer( - self, optimizer: Optional[Callable], optimizer_args: Optional[Dict] - ) -> Optional[Callable]: - """Loads the optimizer.""" + def _configure_optimizers( + self, + optimizer: Optional[Callable], + optimizer_args: Optional[Dict], + lr_scheduler: Optional[Callable], + lr_scheduler_args: Optional[Dict], + ) -> Tuple[Optional[Callable], Optional[Callable]]: + """Loads the optimizers.""" if optimizer is not None: _optimizer = optimizer(self._network.parameters(), **optimizer_args) else: _optimizer = None - return _optimizer - def _load_lr_scheduler( - self, lr_scheduler: Optional[Callable], lr_scheduler_args: Optional[Dict] - ) -> Optional[Callable]: - """Loads learning rate scheduler.""" if self._optimizer and lr_scheduler is not None: if "OneCycleLR" in str(lr_scheduler): lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders["train"]) _lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args) else: _lr_scheduler = None - return _lr_scheduler + + return _optimizer, _lr_scheduler @property def __name__(self) -> str: diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py index 0a0ab2d..0fd7afd 100644 --- a/src/text_recognizer/models/character_model.py +++ b/src/text_recognizer/models/character_model.py @@ -44,6 +44,7 @@ class CharacterModel(Model): self.tensor_transform = ToTensor() self.softmax = nn.Softmax(dim=0) + @torch.no_grad() def predict_on_image( self, image: Union[np.ndarray, torch.Tensor] ) -> Tuple[str, float]: @@ -64,10 +65,9 @@ class CharacterModel(Model): # If the image is an unscaled tensor. 
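predict_on_image is now decorated with @torch.no_grad() instead of opening the context manager inside the method body. A minimal, self-contained sketch (a made-up module, not the CharacterModel) showing that the decorator form disables gradient tracking for the whole call:

import torch
from torch import nn

net = nn.Linear(4, 2)

@torch.no_grad()
def predict(x: torch.Tensor) -> torch.Tensor:
    # The forward pass runs with autograd disabled, exactly as if the body
    # were wrapped in `with torch.no_grad():`.
    return net(x)

out = predict(torch.ones(1, 4))
assert not out.requires_grad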
image = image.type("torch.FloatTensor") / 255 - with torch.no_grad(): - # Put the image tensor on the device the model weights are on. - image = image.to(self.device) - logits = self.network(image) + # Put the image tensor on the device the model weights are on. + image = image.to(self.device) + logits = self.network(image) prediction = self.softmax(logits.data.squeeze()) diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py index e6b6946..a83ca35 100644 --- a/src/text_recognizer/networks/__init__.py +++ b/src/text_recognizer/networks/__init__.py @@ -1,5 +1,6 @@ """Network modules.""" from .lenet import LeNet from .mlp import MLP +from .residual_network import ResidualNetwork -__all__ = ["MLP", "LeNet"] +__all__ = ["MLP", "LeNet", "ResidualNetwork"] diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py index cbc58fc..91d3f2c 100644 --- a/src/text_recognizer/networks/lenet.py +++ b/src/text_recognizer/networks/lenet.py @@ -5,6 +5,8 @@ from einops.layers.torch import Rearrange import torch from torch import nn +from text_recognizer.networks.misc import activation_function + class LeNet(nn.Module): """LeNet network.""" @@ -16,8 +18,7 @@ class LeNet(nn.Module): hidden_size: Tuple[int, ...] = (9216, 128), dropout_rate: float = 0.2, output_size: int = 10, - activation_fn: Optional[Callable] = None, - activation_fn_args: Optional[Dict] = None, + activation_fn: Optional[str] = "relu", ) -> None: """The LeNet network. @@ -28,18 +29,12 @@ class LeNet(nn.Module): Defaults to (9216, 128). dropout_rate (float): The dropout rate. Defaults to 0.2. output_size (int): Number of classes. Defaults to 10. - activation_fn (Optional[Callable]): The non-linear activation function. Defaults to - nn.ReLU(inplace). - activation_fn_args (Optional[Dict]): The arguments for the activation function. Defaults to None. + activation_fn (Optional[str]): The name of non-linear activation function. Defaults to relu. """ super().__init__() - if activation_fn is not None: - activation_fn_args = activation_fn_args or {} - activation_fn = getattr(nn, activation_fn)(**activation_fn_args) - else: - activation_fn = nn.ReLU(inplace=True) + activation_fn = activation_function(activation_fn) self.layers = [ nn.Conv2d( @@ -66,7 +61,7 @@ class LeNet(nn.Module): self.layers = nn.Sequential(*self.layers) def forward(self, x: torch.Tensor) -> torch.Tensor: - """The feedforward.""" + """The feedforward pass.""" # If batch dimenstion is missing, it needs to be added. if len(x.shape) == 3: x = x.unsqueeze(0) diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/misc.py index 2fbab8f..6f61b5d 100644 --- a/src/text_recognizer/networks/misc.py +++ b/src/text_recognizer/networks/misc.py @@ -1,9 +1,9 @@ """Miscellaneous neural network functionality.""" -from typing import Tuple +from typing import Tuple, Type from einops import rearrange import torch -from torch.nn import Unfold +from torch import nn def sliding_window( @@ -20,10 +20,24 @@ def sliding_window( torch.Tensor: A tensor with the shape (batch, patches, height, width). """ - unfold = Unfold(kernel_size=patch_size, stride=stride) + unfold = nn.Unfold(kernel_size=patch_size, stride=stride) # Preform the slidning window, unsqueeze as the channel dimesion is lost. 
patches = unfold(images).unsqueeze(1) patches = rearrange( patches, "b c (h w) t -> b t c h w", h=patch_size[0], w=patch_size[1] ) return patches + + +def activation_function(activation: str) -> Type[nn.Module]: + """Returns the callable activation function.""" + activation_fns = nn.ModuleDict( + [ + ["gelu", nn.GELU()], + ["leaky_relu", nn.LeakyReLU(negative_slope=1.0e-2, inplace=True)], + ["none", nn.Identity()], + ["relu", nn.ReLU(inplace=True)], + ["selu", nn.SELU(inplace=True)], + ] + ) + return activation_fns[activation.lower()] diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py index ac2c825..acebdaa 100644 --- a/src/text_recognizer/networks/mlp.py +++ b/src/text_recognizer/networks/mlp.py @@ -5,6 +5,8 @@ from einops.layers.torch import Rearrange import torch from torch import nn +from text_recognizer.networks.misc import activation_function + class MLP(nn.Module): """Multi layered perceptron network.""" @@ -16,8 +18,7 @@ class MLP(nn.Module): hidden_size: Union[int, List] = 128, num_layers: int = 3, dropout_rate: float = 0.2, - activation_fn: Optional[Callable] = None, - activation_fn_args: Optional[Dict] = None, + activation_fn: str = "relu", ) -> None: """Initialization of the MLP network. @@ -27,18 +28,13 @@ class MLP(nn.Module): hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128. num_layers (int): The number of hidden layers. Defaults to 3. dropout_rate (float): The dropout rate at each layer. Defaults to 0.2. - activation_fn (Optional[Callable]): The activation function in the hidden layers. Defaults to - None. - activation_fn_args (Optional[Dict]): The arguments for the activation function. Defaults to None. + activation_fn (str): Name of the activation function in the hidden layers. Defaults to + relu. """ super().__init__() - if activation_fn is not None: - activation_fn_args = activation_fn_args or {} - activation_fn = getattr(nn, activation_fn)(**activation_fn_args) - else: - activation_fn = nn.ReLU(inplace=True) + activation_fn = activation_function(activation_fn) if isinstance(hidden_size, int): hidden_size = [hidden_size] * num_layers @@ -65,7 +61,7 @@ class MLP(nn.Module): self.layers = nn.Sequential(*self.layers) def forward(self, x: torch.Tensor) -> torch.Tensor: - """The feedforward.""" + """The feedforward pass.""" # If batch dimenstion is missing, it needs to be added. 
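The new activation_function helper in text_recognizer/networks/misc.py replaces the old getattr-based construction in LeNet and MLP: callers pass a case-insensitive name and get back an activation module instance (despite the Type[nn.Module] annotation). A small usage sketch:

from torch import nn
from text_recognizer.networks.misc import activation_function

activation = activation_function("selu")
assert isinstance(activation, nn.SELU)

# Names not present in the mapping raise a KeyError from the underlying
# nn.ModuleDict lookup.
try:
    activation_function("swish")
except KeyError:
    pass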
if len(x.shape) == 3: x = x.unsqueeze(0) diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py index 23394b0..47e351a 100644 --- a/src/text_recognizer/networks/residual_network.py +++ b/src/text_recognizer/networks/residual_network.py @@ -1 +1,315 @@ """Residual CNN.""" +from functools import partial +from typing import Callable, Dict, List, Optional, Type, Union + +from einops.layers.torch import Rearrange, Reduce +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.misc import activation_function + + +class Conv2dAuto(nn.Conv2d): + """Convolution with auto padding based on kernel size.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2) + + +def conv_bn(in_channels: int, out_channels: int, *args, **kwargs) -> nn.Sequential: + """3x3 convolution with batch norm.""" + conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,) + return nn.Sequential( + conv3x3(in_channels, out_channels, *args, **kwargs), + nn.BatchNorm2d(out_channels), + ) + + +class IdentityBlock(nn.Module): + """Residual with identity block.""" + + def __init__( + self, in_channels: int, out_channels: int, activation: str = "relu" + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.blocks = nn.Identity() + self.activation_fn = activation_function(activation) + self.shortcut = nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + residual = x + if self.apply_shortcut: + residual = self.shortcut(x) + x = self.blocks(x) + x += residual + x = self.activation_fn(x) + return x + + @property + def apply_shortcut(self) -> bool: + """Check if shortcut should be applied.""" + return self.in_channels != self.out_channels + + +class ResidualBlock(IdentityBlock): + """Residual with nonlinear shortcut.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + expansion: int = 1, + downsampling: int = 1, + *args, + **kwargs + ) -> None: + """Short summary. + + Args: + in_channels (int): Number of in channels. + out_channels (int): umber of out channels. + expansion (int): Expansion factor of the out channels. Defaults to 1. + downsampling (int): Downsampling factor used in stride. Defaults to 1. + *args (type): Extra arguments. + **kwargs (type): Extra key value arguments. 
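Conv2dAuto, defined at the top of residual_network.py above, derives its padding from the kernel size so that stride-1 convolutions preserve spatial dimensions. A short sketch of that behaviour, assuming the class is imported directly from the new module:

import torch
from text_recognizer.networks.residual_network import Conv2dAuto

conv = Conv2dAuto(in_channels=32, out_channels=64, kernel_size=3, bias=False)
assert conv.padding == (1, 1)              # kernel_size 3 // 2 = 1 in both dimensions

x = torch.ones(1, 32, 28, 28)
assert conv(x).shape == (1, 64, 28, 28)    # height and width are preserved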
+ + """ + super().__init__(in_channels, out_channels, *args, **kwargs) + self.expansion = expansion + self.downsampling = downsampling + + self.shortcut = ( + nn.Sequential( + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.expanded_channels, + kernel_size=1, + stride=self.downsampling, + bias=False, + ), + nn.BatchNorm2d(self.expanded_channels), + ) + if self.apply_shortcut + else None + ) + + @property + def expanded_channels(self) -> int: + """Computes the expanded output channels.""" + return self.out_channels * self.expansion + + @property + def apply_shortcut(self) -> bool: + """Check if shortcut should be applied.""" + return self.in_channels != self.expanded_channels + + +class BasicBlock(ResidualBlock): + """Basic ResNet block.""" + + expansion = 1 + + def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None: + super().__init__(in_channels, out_channels, *args, **kwargs) + self.blocks = nn.Sequential( + conv_bn( + in_channels=self.in_channels, + out_channels=self.out_channels, + bias=False, + stride=self.downsampling, + ), + self.activation_fn, + conv_bn( + in_channels=self.out_channels, + out_channels=self.expanded_channels, + bias=False, + ), + ) + + +class BottleNeckBlock(ResidualBlock): + """Bottleneck block to increase depth while minimizing parameter size.""" + + expansion = 4 + + def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None: + super().__init__(in_channels, out_channels, *args, **kwargs) + self.blocks = nn.Sequential( + conv_bn( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=1, + ), + self.activation_fn, + conv_bn( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=self.downsampling, + ), + self.activation_fn, + conv_bn( + in_channels=self.out_channels, + out_channels=self.expanded_channels, + kernel_size=1, + ), + ) + + +class ResidualLayer(nn.Module): + """ResNet layer.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + block: BasicBlock = BasicBlock, + num_blocks: int = 1, + *args, + **kwargs + ) -> None: + super().__init__() + downsampling = 2 if in_channels != out_channels else 1 + self.blocks = nn.Sequential( + block( + in_channels, out_channels, *args, **kwargs, downsampling=downsampling + ), + *[ + block( + out_channels * block.expansion, + out_channels, + downsampling=1, + *args, + **kwargs + ) + for _ in range(num_blocks - 1) + ] + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + x = self.blocks(x) + return x + + +class Encoder(nn.Module): + """Encoder network.""" + + def __init__( + self, + in_channels: int = 1, + block_sizes: List[int] = (32, 64), + depths: List[int] = (2, 2), + activation: str = "relu", + block: Type[nn.Module] = BasicBlock, + *args, + **kwargs + ) -> None: + super().__init__() + + self.block_sizes = block_sizes + self.depths = depths + self.activation = activation + + self.gate = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, + out_channels=self.block_sizes[0], + kernel_size=3, + stride=2, + padding=3, + bias=False, + ), + nn.BatchNorm2d(self.block_sizes[0]), + activation_function(self.activation), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ) + + self.blocks = self._configure_blocks(block) + + def _configure_blocks( + self, block: Type[nn.Module], *args, **kwargs + ) -> nn.Sequential: + channels = [self.block_sizes[0]] + list( + zip(self.block_sizes, self.block_sizes[1:]) + ) + blocks = [ + ResidualLayer( + in_channels=channels[0], + 
out_channels=channels[0], + num_blocks=self.depths[0], + block=block, + activation=self.activation, + *args, + **kwargs + ) + ] + blocks += [ + ResidualLayer( + in_channels=in_channels * block.expansion, + out_channels=out_channels, + num_blocks=num_blocks, + block=block, + activation=self.activation, + *args, + **kwargs + ) + for (in_channels, out_channels), num_blocks in zip( + channels[1:], self.depths[1:] + ) + ] + + return nn.Sequential(*blocks) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + # If batch dimenstion is missing, it needs to be added. + if len(x.shape) == 3: + x = x.unsqueeze(0) + x = self.gate(x) + return self.blocks(x) + + +class Decoder(nn.Module): + """Classification head.""" + + def __init__(self, in_features: int, num_classes: int = 80) -> None: + super().__init__() + self.decoder = nn.Sequential( + Reduce("b c h w -> b c", "mean"), + nn.Linear(in_features=in_features, out_features=num_classes), + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + return self.decoder(x) + + +class ResidualNetwork(nn.Module): + """Full residual network.""" + + def __init__(self, in_channels: int, num_classes: int, *args, **kwargs) -> None: + super().__init__() + self.encoder = Encoder(in_channels, *args, **kwargs) + self.decoder = Decoder( + in_features=self.encoder.blocks[-1].blocks[-1].expanded_channels, + num_classes=num_classes, + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + x = self.encoder(x) + x = self.decoder(x) + return x diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt index 81ef9be..676eb44 100644 Binary files a/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt and b/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt differ diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt index 49bd166..86cf103 100644 Binary files a/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt and b/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt differ diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt new file mode 100644 index 0000000..008beb2 Binary files /dev/null and b/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt differ diff --git a/src/training/callbacks/__init__.py b/src/training/callbacks/__init__.py deleted file mode 100644 index fbcc285..0000000 --- a/src/training/callbacks/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -"""The callback modules used in the training script.""" -from .base import Callback, CallbackList, Checkpoint -from .early_stopping import EarlyStopping -from .lr_schedulers import CyclicLR, MultiStepLR, OneCycleLR, ReduceLROnPlateau, StepLR -from .wandb_callbacks import WandbCallback, WandbImageLogger - -__all__ = [ - "Callback", - "CallbackList", - "Checkpoint", - "EarlyStopping", - "WandbCallback", - "WandbImageLogger", - "CyclicLR", - "MultiStepLR", - "OneCycleLR", - 
"ReduceLROnPlateau", - "StepLR", -] diff --git a/src/training/callbacks/base.py b/src/training/callbacks/base.py deleted file mode 100644 index e0d91e6..0000000 --- a/src/training/callbacks/base.py +++ /dev/null @@ -1,240 +0,0 @@ -"""Metaclass for callback functions.""" - -from enum import Enum -from typing import Callable, Dict, List, Type, Union - -from loguru import logger -import numpy as np -import torch - -from text_recognizer.models import Model - - -class ModeKeys: - """Mode keys for CallbackList.""" - - TRAIN = "train" - VALIDATION = "validation" - - -class Callback: - """Metaclass for callbacks used in training.""" - - def __init__(self) -> None: - """Initializes the Callback instance.""" - self.model = None - - def set_model(self, model: Type[Model]) -> None: - """Set the model.""" - self.model = model - - def on_fit_begin(self) -> None: - """Called when fit begins.""" - pass - - def on_fit_end(self) -> None: - """Called when fit ends.""" - pass - - def on_epoch_begin(self, epoch: int, logs: Dict = {}) -> None: - """Called at the beginning of an epoch. Only used in training mode.""" - pass - - def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None: - """Called at the end of an epoch. Only used in training mode.""" - pass - - def on_train_batch_begin(self, batch: int, logs: Dict = {}) -> None: - """Called at the beginning of an epoch.""" - pass - - def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None: - """Called at the end of an epoch.""" - pass - - def on_validation_batch_begin(self, batch: int, logs: Dict = {}) -> None: - """Called at the beginning of an epoch.""" - pass - - def on_validation_batch_end(self, batch: int, logs: Dict = {}) -> None: - """Called at the end of an epoch.""" - pass - - -class CallbackList: - """Container for abstracting away callback calls.""" - - mode_keys = ModeKeys() - - def __init__(self, model: Type[Model], callbacks: List[Callback] = None) -> None: - """Container for `Callback` instances. - - This object wraps a list of `Callback` instances and allows them all to be - called via a single end point. - - Args: - model (Type[Model]): A `Model` instance. - callbacks (List[Callback]): List of `Callback` instances. Defaults to None. 
- - """ - - self._callbacks = callbacks or [] - if model: - self.set_model(model) - - def set_model(self, model: Type[Model]) -> None: - """Set the model for all callbacks.""" - self.model = model - for callback in self._callbacks: - callback.set_model(model=self.model) - - def append(self, callback: Type[Callback]) -> None: - """Append new callback to callback list.""" - self.callbacks.append(callback) - - def on_fit_begin(self) -> None: - """Called when fit begins.""" - for callback in self._callbacks: - callback.on_fit_begin() - - def on_fit_end(self) -> None: - """Called when fit ends.""" - for callback in self._callbacks: - callback.on_fit_end() - - def on_epoch_begin(self, epoch: int, logs: Dict = {}) -> None: - """Called at the beginning of an epoch.""" - for callback in self._callbacks: - callback.on_epoch_begin(epoch, logs) - - def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None: - """Called at the end of an epoch.""" - for callback in self._callbacks: - callback.on_epoch_end(epoch, logs) - - def _call_batch_hook( - self, mode: str, hook: str, batch: int, logs: Dict = {} - ) -> None: - """Helper function for all batch_{begin | end} methods.""" - if hook == "begin": - self._call_batch_begin_hook(mode, batch, logs) - elif hook == "end": - self._call_batch_end_hook(mode, batch, logs) - else: - raise ValueError(f"Unrecognized hook {hook}.") - - def _call_batch_begin_hook(self, mode: str, batch: int, logs: Dict = {}) -> None: - """Helper function for all `on_*_batch_begin` methods.""" - hook_name = f"on_{mode}_batch_begin" - self._call_batch_hook_helper(hook_name, batch, logs) - - def _call_batch_end_hook(self, mode: str, batch: int, logs: Dict = {}) -> None: - """Helper function for all `on_*_batch_end` methods.""" - hook_name = f"on_{mode}_batch_end" - self._call_batch_hook_helper(hook_name, batch, logs) - - def _call_batch_hook_helper( - self, hook_name: str, batch: int, logs: Dict = {} - ) -> None: - """Helper function for `on_*_batch_begin` methods.""" - for callback in self._callbacks: - hook = getattr(callback, hook_name) - hook(batch, logs) - - def on_train_batch_begin(self, batch: int, logs: Dict = {}) -> None: - """Called at the beginning of an epoch.""" - self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch) - - def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None: - """Called at the end of an epoch.""" - self._call_batch_hook(self.mode_keys.TRAIN, "end", batch) - - def on_validation_batch_begin(self, batch: int, logs: Dict = {}) -> None: - """Called at the beginning of an epoch.""" - self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch) - - def on_validation_batch_end(self, batch: int, logs: Dict = {}) -> None: - """Called at the end of an epoch.""" - self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch) - - def __iter__(self) -> iter: - """Iter function for callback list.""" - return iter(self._callbacks) - - -class Checkpoint(Callback): - """Saving model parameters at the end of each epoch.""" - - mode_dict = { - "min": torch.lt, - "max": torch.gt, - } - - def __init__( - self, monitor: str = "accuracy", mode: str = "auto", min_delta: float = 0.0 - ) -> None: - """Monitors a quantity that will allow us to determine the best model weights. - - Args: - monitor (str): Name of the quantity to monitor. Defaults to "accuracy". - mode (str): Description of parameter `mode`. Defaults to "auto". - min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. 
- - """ - super().__init__() - self.monitor = monitor - self.mode = mode - self.min_delta = torch.tensor(min_delta) - - if mode not in ["auto", "min", "max"]: - logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.") - - self.mode = "auto" - - if self.mode == "auto": - if "accuracy" in self.monitor: - self.mode = "max" - else: - self.mode = "min" - logger.debug( - f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}." - ) - - torch_inf = torch.tensor(np.inf) - self.min_delta *= 1 if self.monitor_op == torch.gt else -1 - self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf - - @property - def monitor_op(self) -> float: - """Returns the comparison method.""" - return self.mode_dict[self.mode] - - def on_epoch_end(self, epoch: int, logs: Dict) -> None: - """Saves a checkpoint for the network parameters. - - Args: - epoch (int): The current epoch. - logs (Dict): The log containing the monitored metrics. - - """ - current = self.get_monitor_value(logs) - if current is None: - return - if self.monitor_op(current - self.min_delta, self.best_score): - self.best_score = current - is_best = True - else: - is_best = False - - self.model.save_checkpoint(is_best, epoch, self.monitor) - - def get_monitor_value(self, logs: Dict) -> Union[float, None]: - """Extracts the monitored value.""" - monitor_value = logs.get(self.monitor) - if monitor_value is None: - logger.warning( - f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available" - + f"metrics are: {','.join(list(logs.keys()))}" - ) - return None - return monitor_value diff --git a/src/training/callbacks/early_stopping.py b/src/training/callbacks/early_stopping.py deleted file mode 100644 index c9b7907..0000000 --- a/src/training/callbacks/early_stopping.py +++ /dev/null @@ -1,107 +0,0 @@ -"""Implements Early stopping for PyTorch model.""" -from typing import Dict, Union - -from loguru import logger -import numpy as np -import torch -from training.callbacks import Callback - - -class EarlyStopping(Callback): - """Stops training when a monitored metric stops improving.""" - - mode_dict = { - "min": torch.lt, - "max": torch.gt, - } - - def __init__( - self, - monitor: str = "val_loss", - min_delta: float = 0.0, - patience: int = 3, - mode: str = "auto", - ) -> None: - """Initializes the EarlyStopping callback. - - Args: - monitor (str): Description of parameter `monitor`. Defaults to "val_loss". - min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. - patience (int): Description of parameter `patience`. Defaults to 3. - mode (str): Description of parameter `mode`. Defaults to "auto". - - """ - super().__init__() - self.monitor = monitor - self.patience = patience - self.min_delta = torch.tensor(min_delta) - self.mode = mode - self.wait_count = 0 - self.stopped_epoch = 0 - - if mode not in ["auto", "min", "max"]: - logger.warning( - f"EarlyStopping mode {mode} is unkown, fallback to auto mode." - ) - - self.mode = "auto" - - if self.mode == "auto": - if "accuracy" in self.monitor: - self.mode = "max" - else: - self.mode = "min" - logger.debug( - f"EarlyStopping mode set to {self.mode} for monitoring {self.monitor}." 
- ) - - self.torch_inf = torch.tensor(np.inf) - self.min_delta *= 1 if self.monitor_op == torch.gt else -1 - self.best_score = ( - self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf - ) - - @property - def monitor_op(self) -> float: - """Returns the comparison method.""" - return self.mode_dict[self.mode] - - def on_fit_begin(self) -> Union[torch.lt, torch.gt]: - """Reset the early stopping variables for reuse.""" - self.wait_count = 0 - self.stopped_epoch = 0 - self.best_score = ( - self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf - ) - - def on_epoch_end(self, epoch: int, logs: Dict) -> None: - """Computes the early stop criterion.""" - current = self.get_monitor_value(logs) - if current is None: - return - if self.monitor_op(current - self.min_delta, self.best_score): - self.best_score = current - self.wait_count = 0 - else: - self.wait_count += 1 - if self.wait_count >= self.patience: - self.stopped_epoch = epoch - self.model.stop_training = True - - def on_fit_end(self) -> None: - """Logs if early stopping was used.""" - if self.stopped_epoch > 0: - logger.info( - f"Stopped training at epoch {self.stopped_epoch + 1} with early stopping." - ) - - def get_monitor_value(self, logs: Dict) -> Union[torch.Tensor, None]: - """Extracts the monitor value.""" - monitor_value = logs.get(self.monitor) - if monitor_value is None: - logger.warning( - f"Early stopping is conditioned on metric {self.monitor} which is not available. Available" - + f"metrics are: {','.join(list(logs.keys()))}" - ) - return None - return torch.tensor(monitor_value) diff --git a/src/training/callbacks/lr_schedulers.py b/src/training/callbacks/lr_schedulers.py deleted file mode 100644 index 00c7e9b..0000000 --- a/src/training/callbacks/lr_schedulers.py +++ /dev/null @@ -1,97 +0,0 @@ -"""Callbacks for learning rate schedulers.""" -from typing import Callable, Dict, List, Optional, Type - -from training.callbacks import Callback - -from text_recognizer.models import Model - - -class StepLR(Callback): - """Callback for StepLR.""" - - def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler - - def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None: - """Takes a step at the end of every epoch.""" - self.lr_scheduler.step() - - -class MultiStepLR(Callback): - """Callback for MultiStepLR.""" - - def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler - - def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None: - """Takes a step at the end of every epoch.""" - self.lr_scheduler.step() - - -class ReduceLROnPlateau(Callback): - """Callback for ReduceLROnPlateau.""" - - def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler - - def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None: - """Takes a step at the end of every epoch.""" - val_loss = logs["val_loss"] - self.lr_scheduler.step(val_loss) - - -class CyclicLR(Callback): - """Callback for 
CyclicLR.""" - - def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler - - def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None: - """Takes a step at the end of every training batch.""" - self.lr_scheduler.step() - - -class OneCycleLR(Callback): - """Callback for OneCycleLR.""" - - def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler - - def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None: - """Takes a step at the end of every training batch.""" - self.lr_scheduler.step() diff --git a/src/training/callbacks/wandb_callbacks.py b/src/training/callbacks/wandb_callbacks.py deleted file mode 100644 index 6ada6df..0000000 --- a/src/training/callbacks/wandb_callbacks.py +++ /dev/null @@ -1,93 +0,0 @@ -"""Callbacks using wandb.""" -from typing import Callable, Dict, List, Optional, Type - -import numpy as np -from torchvision.transforms import Compose, ToTensor -from training.callbacks import Callback -import wandb - -from text_recognizer.datasets import Transpose -from text_recognizer.models.base import Model - - -class WandbCallback(Callback): - """A custom W&B metric logger for the trainer.""" - - def __init__(self, log_batch_frequency: int = None) -> None: - """Short summary. - - Args: - log_batch_frequency (int): If None, metrics will be logged every epoch. - If set to an integer, callback will log every metrics every log_batch_frequency. - - """ - super().__init__() - self.log_batch_frequency = log_batch_frequency - - def _on_batch_end(self, batch: int, logs: Dict) -> None: - if self.log_batch_frequency and batch % self.log_batch_frequency == 0: - wandb.log(logs, commit=True) - - def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None: - """Logs training metrics.""" - if logs is not None: - self._on_batch_end(batch, logs) - - def on_validation_batch_end(self, batch: int, logs: Dict = {}) -> None: - """Logs validation metrics.""" - if logs is not None: - self._on_batch_end(batch, logs) - - def on_epoch_end(self, epoch: int, logs: Dict) -> None: - """Logs at epoch end.""" - wandb.log(logs, commit=True) - - -class WandbImageLogger(Callback): - """Custom W&B callback for image logging.""" - - def __init__( - self, - example_indices: Optional[List] = None, - num_examples: int = 4, - transfroms: Optional[Callable] = None, - ) -> None: - """Initializes the WandbImageLogger with the model to train. - - Args: - example_indices (Optional[List]): Indices for validation images. Defaults to None. - num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4. - transfroms (Optional[Callable]): Transforms to use on the validation images, e.g. transpose. Defaults to - None. 
- - """ - - super().__init__() - self.example_indices = example_indices - self.num_examples = num_examples - self.transfroms = transfroms - if self.transfroms is None: - self.transforms = Compose([Transpose()]) - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and extracts validation images from the dataset.""" - self.model = model - data_loader = self.model.data_loaders["val"] - if self.example_indices is None: - self.example_indices = np.random.randint( - 0, len(data_loader.dataset.data), self.num_examples - ) - self.val_images = data_loader.dataset.data[self.example_indices] - self.val_targets = data_loader.dataset.targets[self.example_indices].numpy() - - def on_epoch_end(self, epoch: int, logs: Dict) -> None: - """Get network predictions on validation images.""" - images = [] - for i, image in enumerate(self.val_images): - image = self.transforms(image) - pred, conf = self.model.predict_on_image(image) - ground_truth = self.model.mapper(int(self.val_targets[i])) - caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}" - images.append(wandb.Image(image, caption=caption)) - - wandb.log({"examples": images}, commit=False) diff --git a/src/training/experiments/sample_experiment.yml b/src/training/experiments/sample_experiment.yml index 355305c..bae02ac 100644 --- a/src/training/experiments/sample_experiment.yml +++ b/src/training/experiments/sample_experiment.yml @@ -9,25 +9,32 @@ experiments: seed: 4711 data_loader_args: splits: [train, val] - batch_size: 256 shuffle: true num_workers: 8 cuda: true model: CharacterModel metrics: [accuracy] - network: MLP + # network: MLP + # network_args: + # input_size: 784 + # hidden_size: 512 + # output_size: 80 + # num_layers: 3 + # dropout_rate: 0 + # activation_fn: SELU + network: ResidualNetwork network_args: - input_size: 784 - output_size: 62 - num_layers: 3 - activation_fn: GELU + in_channels: 1 + num_classes: 80 + depths: [1, 1] + block_sizes: [128, 256] # network: LeNet # network_args: # output_size: 62 # activation_fn: GELU train_args: batch_size: 256 - epochs: 16 + epochs: 32 criterion: CrossEntropyLoss criterion_args: weight: null @@ -43,20 +50,24 @@ experiments: # centered: false optimizer: AdamW optimizer_args: - lr: 1.e-2 + lr: 1.e-03 betas: [0.9, 0.999] eps: 1.e-08 - weight_decay: 0 + # weight_decay: 5.e-4 amsgrad: false # lr_scheduler: null lr_scheduler: OneCycleLR lr_scheduler_args: - max_lr: 1.e-3 - epochs: 16 - callbacks: [Checkpoint, EarlyStopping, WandbCallback, WandbImageLogger, OneCycleLR] + max_lr: 1.e-03 + epochs: 32 + anneal_strategy: linear + callbacks: [Checkpoint, ProgressBar, EarlyStopping, WandbCallback, WandbImageLogger, OneCycleLR] callback_args: Checkpoint: monitor: val_accuracy + ProgressBar: + epochs: 32 + log_batch_frequency: 100 EarlyStopping: monitor: val_loss min_delta: 0.0 @@ -68,5 +79,5 @@ experiments: num_examples: 4 OneCycleLR: null - verbosity: 2 # 0, 1, 2 + verbosity: 1 # 0, 1, 2 resume_experiment: null diff --git a/src/training/population_based_training/__init__.py b/src/training/population_based_training/__init__.py deleted file mode 100644 index 868d739..0000000 --- a/src/training/population_based_training/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""TBC.""" diff --git a/src/training/population_based_training/population_based_training.py b/src/training/population_based_training/population_based_training.py deleted file mode 100644 index 868d739..0000000 --- a/src/training/population_based_training/population_based_training.py +++ /dev/null @@ -1 +0,0 
@@ -"""TBC.""" diff --git a/src/training/prepare_experiments.py b/src/training/prepare_experiments.py index 97c0304..4c3f9ba 100644 --- a/src/training/prepare_experiments.py +++ b/src/training/prepare_experiments.py @@ -7,11 +7,11 @@ from loguru import logger import yaml -# flake8: noqa: S404,S607,S603 def run_experiments(experiments_filename: str) -> None: """Run experiment from file.""" with open(experiments_filename) as f: experiments_config = yaml.safe_load(f) + num_experiments = len(experiments_config["experiments"]) for index in range(num_experiments): experiment_config = experiments_config["experiments"][index] @@ -27,10 +27,10 @@ def run_experiments(experiments_filename: str) -> None: type=str, help="Filename of Yaml file of experiments to run.", ) -def main(experiments_filename: str) -> None: +def run_cli(experiments_filename: str) -> None: """Parse command-line arguments and run experiments from provided file.""" run_experiments(experiments_filename) if __name__ == "__main__": - main() + run_cli() diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py index d278dc2..8c063ff 100644 --- a/src/training/run_experiment.py +++ b/src/training/run_experiment.py @@ -6,18 +6,20 @@ import json import os from pathlib import Path import re -from typing import Callable, Dict, Tuple +from typing import Callable, Dict, Tuple, Type import click from loguru import logger import torch from tqdm import tqdm -from training.callbacks import CallbackList from training.gpu_manager import GPUManager -from training.train import Trainer +from training.trainer.callbacks import CallbackList +from training.trainer.train import Trainer import wandb import yaml +from text_recognizer.models import Model + EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments" @@ -35,7 +37,7 @@ def get_level(experiment_config: Dict) -> int: return 10 -def create_experiment_dir(model: Callable, experiment_config: Dict) -> Path: +def create_experiment_dir(model: Type[Model], experiment_config: Dict) -> Path: """Create new experiment.""" EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True) experiment_dir = EXPERIMENTS_DIRNAME / model.__name__ @@ -67,6 +69,8 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict] """Loads all modules and arguments.""" # Import the data loader arguments. data_loader_args = experiment_config.get("data_loader_args", {}) + train_args = experiment_config.get("train_args", {}) + data_loader_args["batch_size"] = train_args["batch_size"] data_loader_args["dataset"] = experiment_config["dataset"] data_loader_args["dataset_args"] = experiment_config.get("dataset_args", {}) @@ -94,7 +98,7 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict] optimizer_args = experiment_config.get("optimizer_args", {}) # Callbacks - callback_modules = importlib.import_module("training.callbacks") + callback_modules = importlib.import_module("training.trainer.callbacks") callbacks = [ getattr(callback_modules, callback)( **check_args(experiment_config["callback_args"][callback]) @@ -208,6 +212,7 @@ def run_experiment( with open(str(config_path), "w") as f: yaml.dump(experiment_config, f) + # Train the model. trainer = Trainer( model=model, model_dir=model_dir, @@ -247,7 +252,7 @@ def run_experiment( @click.option( "--nowandb", is_flag=False, help="If true, do not use wandb for this run." 
) -def main(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None: +def run_cli(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None: """Run experiment.""" if gpu < 0: gpu_manager = GPUManager(True) @@ -260,4 +265,4 @@ def main(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None: if __name__ == "__main__": - main() + run_cli() diff --git a/src/training/train.py b/src/training/train.py deleted file mode 100644 index aaa0430..0000000 --- a/src/training/train.py +++ /dev/null @@ -1,249 +0,0 @@ -"""Training script for PyTorch models.""" - -from pathlib import Path -import time -from typing import Dict, List, Optional, Tuple, Type - -from loguru import logger -import numpy as np -import torch -from tqdm import tqdm, trange -from training.callbacks import Callback, CallbackList -from training.util import RunningAverage -import wandb - -from text_recognizer.models import Model - - -torch.backends.cudnn.benchmark = True -np.random.seed(4711) -torch.manual_seed(4711) -torch.cuda.manual_seed(4711) - - -class Trainer: - """Trainer for training PyTorch models.""" - - def __init__( - self, - model: Type[Model], - model_dir: Path, - train_args: Dict, - callbacks: CallbackList, - checkpoint_path: Optional[Path] = None, - ) -> None: - """Initialization of the Trainer. - - Args: - model (Type[Model]): A model object. - model_dir (Path): Path to the model directory. - train_args (Dict): The training arguments. - callbacks (CallbackList): List of callbacks to be called. - checkpoint_path (Optional[Path]): The path to a previously trained model. Defaults to None. - - """ - self.model = model - self.model_dir = model_dir - self.checkpoint_path = checkpoint_path - self.start_epoch = 1 - self.epochs = train_args["epochs"] + self.start_epoch - self.callbacks = callbacks - - if self.checkpoint_path is not None: - self.start_epoch = self.model.load_checkpoint(self.checkpoint_path) + 1 - - # Parse the name of the experiment. - experiment_dir = str(self.model_dir.parents[1]).split("/") - self.experiment_name = experiment_dir[-2] + "/" + experiment_dir[-1] - - def training_step( - self, - batch: int, - samples: Tuple[torch.Tensor, torch.Tensor], - loss_avg: Type[RunningAverage], - ) -> Dict: - """Performs the training step.""" - # Pass the tensor to the device for computation. - data, targets = samples - data, targets = ( - data.to(self.model.device), - targets.to(self.model.device), - ) - - # Forward pass. - # Get the network prediction. - output = self.model.network(data) - - # Compute the loss. - loss = self.model.criterion(output, targets) - - # Backward pass. - # Clear the previous gradients. - self.model.optimizer.zero_grad() - - # Compute the gradients. - loss.backward() - - # Perform updates using calculated gradients. - self.model.optimizer.step() - - # Compute metrics. - loss_avg.update(loss.item()) - output = output.data.cpu() - targets = targets.data.cpu() - metrics = { - metric: self.model.metrics[metric](output, targets) - for metric in self.model.metrics - } - metrics["loss"] = loss_avg() - return metrics - - def train(self) -> None: - """Runs the training loop for one epoch.""" - # Set model to traning mode. - self.model.train() - - # Running average for the loss. 
- loss_avg = RunningAverage() - - data_loader = self.model.data_loaders["train"] - - with tqdm( - total=len(data_loader), - leave=False, - unit="step", - bar_format="{n_fmt}/{total_fmt} |{bar:30}| {remaining} {rate_inv_fmt}{postfix}", - ) as t: - for batch, samples in enumerate(data_loader): - self.callbacks.on_train_batch_begin(batch) - - metrics = self.training_step(batch, samples, loss_avg) - - self.callbacks.on_train_batch_end(batch, logs=metrics) - - # Update Tqdm progress bar. - t.set_postfix(**metrics) - t.update() - - def validation_step( - self, - batch: int, - samples: Tuple[torch.Tensor, torch.Tensor], - loss_avg: Type[RunningAverage], - ) -> Dict: - """Performs the validation step.""" - # Pass the tensor to the device for computation. - data, targets = samples - data, targets = ( - data.to(self.model.device), - targets.to(self.model.device), - ) - - # Forward pass. - # Get the network prediction. - output = self.model.network(data) - - # Compute the loss. - loss = self.model.criterion(output, targets) - - # Compute metrics. - loss_avg.update(loss.item()) - output = output.data.cpu() - targets = targets.data.cpu() - metrics = { - metric: self.model.metrics[metric](output, targets) - for metric in self.model.metrics - } - metrics["loss"] = loss.item() - - return metrics - - def validate(self, epoch: Optional[int] = None) -> Dict: - """Runs the validation loop for one epoch.""" - # Set model to eval mode. - self.model.eval() - - # Running average for the loss. - data_loader = self.model.data_loaders["val"] - - # Running average for the loss. - loss_avg = RunningAverage() - - # Summary for the current eval loop. - summary = [] - - with tqdm( - total=len(data_loader), - leave=False, - unit="step", - bar_format="{n_fmt}/{total_fmt} |{bar:30}| {remaining} {rate_inv_fmt}{postfix}", - ) as t: - with torch.no_grad(): - for batch, samples in enumerate(data_loader): - self.callbacks.on_validation_batch_begin(batch) - - metrics = self.validation_step(batch, samples, loss_avg) - - self.callbacks.on_validation_batch_end(batch, logs=metrics) - - summary.append(metrics) - - # Update Tqdm progress bar. - t.set_postfix(**metrics) - t.update() - - # Compute mean of all metrics. - metrics_mean = { - "val_" + metric: np.mean([x[metric] for x in summary]) - for metric in summary[0] - } - if epoch: - logger.debug( - f"Validation metrics at epoch {epoch} - " - + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items()) - ) - else: - logger.debug( - "Validation metrics - " - + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items()) - ) - - return metrics_mean - - def fit(self) -> None: - """Runs the training and evaluation loop.""" - - logger.debug(f"Running an experiment called {self.experiment_name}.") - t_start = time.time() - - self.callbacks.on_fit_begin() - - # TODO: fix progress bar as callback. - # Run the training loop. - for epoch in trange( - self.start_epoch, - self.epochs, - leave=False, - bar_format="{desc}: {n_fmt}/{total_fmt} |{bar:30}| {remaining}{postfix}", - desc="Epoch", - ): - self.callbacks.on_epoch_begin(epoch) - - # Perform one training pass over the training set. - self.train() - - # Evaluate the model on the validation set. - val_metrics = self.validate(epoch) - - self.callbacks.on_epoch_end(epoch, logs=val_metrics) - - if self.model.stop_training: - break - - # Calculate the total training time. 
- t_end = time.time() - t_training = t_end - t_start - - self.callbacks.on_fit_end() - - logger.info(f"Training took {t_training:.2f} s.") diff --git a/src/training/trainer/__init__.py b/src/training/trainer/__init__.py new file mode 100644 index 0000000..de41bfb --- /dev/null +++ b/src/training/trainer/__init__.py @@ -0,0 +1,2 @@ +"""Trainer modules.""" +from .train import Trainer diff --git a/src/training/trainer/callbacks/__init__.py b/src/training/trainer/callbacks/__init__.py new file mode 100644 index 0000000..5942276 --- /dev/null +++ b/src/training/trainer/callbacks/__init__.py @@ -0,0 +1,21 @@ +"""The callback modules used in the training script.""" +from .base import Callback, CallbackList, Checkpoint +from .early_stopping import EarlyStopping +from .lr_schedulers import CyclicLR, MultiStepLR, OneCycleLR, ReduceLROnPlateau, StepLR +from .progress_bar import ProgressBar +from .wandb_callbacks import WandbCallback, WandbImageLogger + +__all__ = [ + "Callback", + "CallbackList", + "Checkpoint", + "EarlyStopping", + "WandbCallback", + "WandbImageLogger", + "CyclicLR", + "MultiStepLR", + "OneCycleLR", + "ProgressBar", + "ReduceLROnPlateau", + "StepLR", +] diff --git a/src/training/trainer/callbacks/base.py b/src/training/trainer/callbacks/base.py new file mode 100644 index 0000000..8df94f3 --- /dev/null +++ b/src/training/trainer/callbacks/base.py @@ -0,0 +1,248 @@ +"""Metaclass for callback functions.""" + +from enum import Enum +from typing import Callable, Dict, List, Optional, Type, Union + +from loguru import logger +import numpy as np +import torch + +from text_recognizer.models import Model + + +class ModeKeys: + """Mode keys for CallbackList.""" + + TRAIN = "train" + VALIDATION = "validation" + + +class Callback: + """Metaclass for callbacks used in training.""" + + def __init__(self) -> None: + """Initializes the Callback instance.""" + self.model = None + + def set_model(self, model: Type[Model]) -> None: + """Set the model.""" + self.model = model + + def on_fit_begin(self) -> None: + """Called when fit begins.""" + pass + + def on_fit_end(self) -> None: + """Called when fit ends.""" + pass + + def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None: + """Called at the beginning of an epoch. Only used in training mode.""" + pass + + def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: + """Called at the end of an epoch. Only used in training mode.""" + pass + + def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None: + """Called at the beginning of an epoch.""" + pass + + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Called at the end of an epoch.""" + pass + + def on_validation_batch_begin( + self, batch: int, logs: Optional[Dict] = None + ) -> None: + """Called at the beginning of an epoch.""" + pass + + def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Called at the end of an epoch.""" + pass + + +class CallbackList: + """Container for abstracting away callback calls.""" + + mode_keys = ModeKeys() + + def __init__(self, model: Type[Model], callbacks: List[Callback] = None) -> None: + """Container for `Callback` instances. + + This object wraps a list of `Callback` instances and allows them all to be + called via a single end point. + + Args: + model (Type[Model]): A `Model` instance. + callbacks (List[Callback]): List of `Callback` instances. Defaults to None. 
+ + """ + + self._callbacks = callbacks or [] + if model: + self.set_model(model) + + def set_model(self, model: Type[Model]) -> None: + """Set the model for all callbacks.""" + self.model = model + for callback in self._callbacks: + callback.set_model(model=self.model) + + def append(self, callback: Type[Callback]) -> None: + """Append new callback to callback list.""" + self.callbacks.append(callback) + + def on_fit_begin(self) -> None: + """Called when fit begins.""" + for callback in self._callbacks: + callback.on_fit_begin() + + def on_fit_end(self) -> None: + """Called when fit ends.""" + for callback in self._callbacks: + callback.on_fit_end() + + def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None: + """Called at the beginning of an epoch.""" + for callback in self._callbacks: + callback.on_epoch_begin(epoch, logs) + + def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: + """Called at the end of an epoch.""" + for callback in self._callbacks: + callback.on_epoch_end(epoch, logs) + + def _call_batch_hook( + self, mode: str, hook: str, batch: int, logs: Optional[Dict] = None + ) -> None: + """Helper function for all batch_{begin | end} methods.""" + if hook == "begin": + self._call_batch_begin_hook(mode, batch, logs) + elif hook == "end": + self._call_batch_end_hook(mode, batch, logs) + else: + raise ValueError(f"Unrecognized hook {hook}.") + + def _call_batch_begin_hook( + self, mode: str, batch: int, logs: Optional[Dict] = None + ) -> None: + """Helper function for all `on_*_batch_begin` methods.""" + hook_name = f"on_{mode}_batch_begin" + self._call_batch_hook_helper(hook_name, batch, logs) + + def _call_batch_end_hook( + self, mode: str, batch: int, logs: Optional[Dict] = None + ) -> None: + """Helper function for all `on_*_batch_end` methods.""" + hook_name = f"on_{mode}_batch_end" + self._call_batch_hook_helper(hook_name, batch, logs) + + def _call_batch_hook_helper( + self, hook_name: str, batch: int, logs: Optional[Dict] = None + ) -> None: + """Helper function for `on_*_batch_begin` methods.""" + for callback in self._callbacks: + hook = getattr(callback, hook_name) + hook(batch, logs) + + def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None: + """Called at the beginning of an epoch.""" + self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch, logs) + + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Called at the end of an epoch.""" + self._call_batch_hook(self.mode_keys.TRAIN, "end", batch, logs) + + def on_validation_batch_begin( + self, batch: int, logs: Optional[Dict] = None + ) -> None: + """Called at the beginning of an epoch.""" + self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch, logs) + + def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Called at the end of an epoch.""" + self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch, logs) + + def __iter__(self) -> iter: + """Iter function for callback list.""" + return iter(self._callbacks) + + +class Checkpoint(Callback): + """Saving model parameters at the end of each epoch.""" + + mode_dict = { + "min": torch.lt, + "max": torch.gt, + } + + def __init__( + self, monitor: str = "accuracy", mode: str = "auto", min_delta: float = 0.0 + ) -> None: + """Monitors a quantity that will allow us to determine the best model weights. + + Args: + monitor (str): Name of the quantity to monitor. Defaults to "accuracy". + mode (str): Description of parameter `mode`. 
Defaults to "auto". + min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. + + """ + super().__init__() + self.monitor = monitor + self.mode = mode + self.min_delta = torch.tensor(min_delta) + + if mode not in ["auto", "min", "max"]: + logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.") + + self.mode = "auto" + + if self.mode == "auto": + if "accuracy" in self.monitor: + self.mode = "max" + else: + self.mode = "min" + logger.debug( + f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}." + ) + + torch_inf = torch.tensor(np.inf) + self.min_delta *= 1 if self.monitor_op == torch.gt else -1 + self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf + + @property + def monitor_op(self) -> float: + """Returns the comparison method.""" + return self.mode_dict[self.mode] + + def on_epoch_end(self, epoch: int, logs: Dict) -> None: + """Saves a checkpoint for the network parameters. + + Args: + epoch (int): The current epoch. + logs (Dict): The log containing the monitored metrics. + + """ + current = self.get_monitor_value(logs) + if current is None: + return + if self.monitor_op(current - self.min_delta, self.best_score): + self.best_score = current + is_best = True + else: + is_best = False + + self.model.save_checkpoint(is_best, epoch, self.monitor) + + def get_monitor_value(self, logs: Dict) -> Union[float, None]: + """Extracts the monitored value.""" + monitor_value = logs.get(self.monitor) + if monitor_value is None: + logger.warning( + f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available" + + f"metrics are: {','.join(list(logs.keys()))}" + ) + return None + return monitor_value diff --git a/src/training/trainer/callbacks/early_stopping.py b/src/training/trainer/callbacks/early_stopping.py new file mode 100644 index 0000000..02b431f --- /dev/null +++ b/src/training/trainer/callbacks/early_stopping.py @@ -0,0 +1,108 @@ +"""Implements Early stopping for PyTorch model.""" +from typing import Dict, Union + +from loguru import logger +import numpy as np +import torch +from torch import Tensor +from training.trainer.callbacks import Callback + + +class EarlyStopping(Callback): + """Stops training when a monitored metric stops improving.""" + + mode_dict = { + "min": torch.lt, + "max": torch.gt, + } + + def __init__( + self, + monitor: str = "val_loss", + min_delta: float = 0.0, + patience: int = 3, + mode: str = "auto", + ) -> None: + """Initializes the EarlyStopping callback. + + Args: + monitor (str): Description of parameter `monitor`. Defaults to "val_loss". + min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. + patience (int): Description of parameter `patience`. Defaults to 3. + mode (str): Description of parameter `mode`. Defaults to "auto". + + """ + super().__init__() + self.monitor = monitor + self.patience = patience + self.min_delta = torch.tensor(min_delta) + self.mode = mode + self.wait_count = 0 + self.stopped_epoch = 0 + + if mode not in ["auto", "min", "max"]: + logger.warning( + f"EarlyStopping mode {mode} is unkown, fallback to auto mode." + ) + + self.mode = "auto" + + if self.mode == "auto": + if "accuracy" in self.monitor: + self.mode = "max" + else: + self.mode = "min" + logger.debug( + f"EarlyStopping mode set to {self.mode} for monitoring {self.monitor}." 
+ ) + + self.torch_inf = torch.tensor(np.inf) + self.min_delta *= 1 if self.monitor_op == torch.gt else -1 + self.best_score = ( + self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf + ) + + @property + def monitor_op(self) -> float: + """Returns the comparison method.""" + return self.mode_dict[self.mode] + + def on_fit_begin(self) -> Union[torch.lt, torch.gt]: + """Reset the early stopping variables for reuse.""" + self.wait_count = 0 + self.stopped_epoch = 0 + self.best_score = ( + self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf + ) + + def on_epoch_end(self, epoch: int, logs: Dict) -> None: + """Computes the early stop criterion.""" + current = self.get_monitor_value(logs) + if current is None: + return + if self.monitor_op(current - self.min_delta, self.best_score): + self.best_score = current + self.wait_count = 0 + else: + self.wait_count += 1 + if self.wait_count >= self.patience: + self.stopped_epoch = epoch + self.model.stop_training = True + + def on_fit_end(self) -> None: + """Logs if early stopping was used.""" + if self.stopped_epoch > 0: + logger.info( + f"Stopped training at epoch {self.stopped_epoch + 1} with early stopping." + ) + + def get_monitor_value(self, logs: Dict) -> Union[Tensor, None]: + """Extracts the monitor value.""" + monitor_value = logs.get(self.monitor) + if monitor_value is None: + logger.warning( + f"Early stopping is conditioned on metric {self.monitor} which is not available. Available" + + f"metrics are: {','.join(list(logs.keys()))}" + ) + return None + return torch.tensor(monitor_value) diff --git a/src/training/trainer/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py new file mode 100644 index 0000000..ba2226a --- /dev/null +++ b/src/training/trainer/callbacks/lr_schedulers.py @@ -0,0 +1,97 @@ +"""Callbacks for learning rate schedulers.""" +from typing import Callable, Dict, List, Optional, Type + +from training.trainer.callbacks import Callback + +from text_recognizer.models import Model + + +class StepLR(Callback): + """Callback for StepLR.""" + + def __init__(self) -> None: + """Initializes the callback.""" + super().__init__() + self.lr_scheduler = None + + def set_model(self, model: Type[Model]) -> None: + """Sets the model and lr scheduler.""" + self.model = model + self.lr_scheduler = self.model.lr_scheduler + + def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: + """Takes a step at the end of every epoch.""" + self.lr_scheduler.step() + + +class MultiStepLR(Callback): + """Callback for MultiStepLR.""" + + def __init__(self) -> None: + """Initializes the callback.""" + super().__init__() + self.lr_scheduler = None + + def set_model(self, model: Type[Model]) -> None: + """Sets the model and lr scheduler.""" + self.model = model + self.lr_scheduler = self.model.lr_scheduler + + def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: + """Takes a step at the end of every epoch.""" + self.lr_scheduler.step() + + +class ReduceLROnPlateau(Callback): + """Callback for ReduceLROnPlateau.""" + + def __init__(self) -> None: + """Initializes the callback.""" + super().__init__() + self.lr_scheduler = None + + def set_model(self, model: Type[Model]) -> None: + """Sets the model and lr scheduler.""" + self.model = model + self.lr_scheduler = self.model.lr_scheduler + + def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: + """Takes a step at the end of every epoch.""" + val_loss = logs["val_loss"] + self.lr_scheduler.step(val_loss) + 
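# Editorial note (not part of this commit): the three callbacks above (StepLR,
# MultiStepLR, ReduceLROnPlateau) step their scheduler once per epoch, while
# CyclicLR and OneCycleLR below step once per training batch. Every callback only
# calls self.model.lr_scheduler.step(); the scheduler itself is assumed to be built
# elsewhere by the model / run_experiment machinery, roughly as in this hedged
# sketch for the OneCycleLR settings used in sample_experiment.yml (the `network`
# and `train_loader` names are illustrative, not from the patch):
#
#     import torch
#
#     optimizer = torch.optim.AdamW(network.parameters(), lr=1e-3)
#     lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
#         optimizer,
#         max_lr=1e-3,
#         epochs=32,
#         steps_per_epoch=len(train_loader),
#         anneal_strategy="linear",
#     )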
+ +class CyclicLR(Callback): + """Callback for CyclicLR.""" + + def __init__(self) -> None: + """Initializes the callback.""" + super().__init__() + self.lr_scheduler = None + + def set_model(self, model: Type[Model]) -> None: + """Sets the model and lr scheduler.""" + self.model = model + self.lr_scheduler = self.model.lr_scheduler + + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Takes a step at the end of every training batch.""" + self.lr_scheduler.step() + + +class OneCycleLR(Callback): + """Callback for OneCycleLR.""" + + def __init__(self) -> None: + """Initializes the callback.""" + super().__init__() + self.lr_scheduler = None + + def set_model(self, model: Type[Model]) -> None: + """Sets the model and lr scheduler.""" + self.model = model + self.lr_scheduler = self.model.lr_scheduler + + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Takes a step at the end of every training batch.""" + self.lr_scheduler.step() diff --git a/src/training/trainer/callbacks/progress_bar.py b/src/training/trainer/callbacks/progress_bar.py new file mode 100644 index 0000000..1970747 --- /dev/null +++ b/src/training/trainer/callbacks/progress_bar.py @@ -0,0 +1,61 @@ +"""Progress bar callback for the training loop.""" +from typing import Dict, Optional + +from tqdm import tqdm +from training.trainer.callbacks import Callback + + +class ProgressBar(Callback): + """A TQDM progress bar for the training loop.""" + + def __init__(self, epochs: int, log_batch_frequency: int = None) -> None: + """Initializes the tqdm callback.""" + self.epochs = epochs + self.log_batch_frequency = log_batch_frequency + self.progress_bar = None + self.val_metrics = {} + + def _configure_progress_bar(self) -> None: + """Configures the tqdm progress bar with custom bar format.""" + self.progress_bar = tqdm( + total=len(self.model.data_loaders["train"]), + leave=True, + unit="step", + mininterval=self.log_batch_frequency, + bar_format="{desc} |{bar:30}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}", + ) + + def _key_abbreviations(self, logs: Dict) -> Dict: + """Changes the length of keys, so that the progress bar fits better.""" + + def rename(key: str) -> str: + """Renames accuracy to acc.""" + return key.replace("accuracy", "acc") + + return {rename(key): value for key, value in logs.items()} + + def on_fit_begin(self) -> None: + """Creates a tqdm progress bar.""" + self._configure_progress_bar() + + def on_epoch_begin(self, epoch: int, logs: Optional[Dict]) -> None: + """Updates the description with the current epoch.""" + self.progress_bar.reset() + self.progress_bar.set_description(f"Epoch {epoch}/{self.epochs}") + + def on_epoch_end(self, epoch: int, logs: Dict) -> None: + """At the end of each epoch, the validation metrics are updated to the progress bar.""" + self.val_metrics = logs + self.progress_bar.set_postfix(**self._key_abbreviations(logs)) + self.progress_bar.update() + + def on_train_batch_end(self, batch: int, logs: Dict) -> None: + """Updates the progress bar for each training step.""" + if self.val_metrics: + logs.update(self.val_metrics) + self.progress_bar.set_postfix(**self._key_abbreviations(logs)) + self.progress_bar.update() + + def on_fit_end(self) -> None: + """Closes the tqdm progress bar.""" + self.progress_bar.close() diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py new file mode 100644 index 0000000..e44c745 --- /dev/null +++ 
b/src/training/trainer/callbacks/wandb_callbacks.py @@ -0,0 +1,93 @@ +"""Callback for W&B.""" +from typing import Callable, Dict, List, Optional, Type + +import numpy as np +from torchvision.transforms import Compose, ToTensor +from training.trainer.callbacks import Callback +import wandb + +from text_recognizer.datasets import Transpose +from text_recognizer.models.base import Model + + +class WandbCallback(Callback): + """A custom W&B metric logger for the trainer.""" + + def __init__(self, log_batch_frequency: int = None) -> None: + """Short summary. + + Args: + log_batch_frequency (int): If None, metrics will be logged every epoch. + If set to an integer, callback will log every metrics every log_batch_frequency. + + """ + super().__init__() + self.log_batch_frequency = log_batch_frequency + + def _on_batch_end(self, batch: int, logs: Dict) -> None: + if self.log_batch_frequency and batch % self.log_batch_frequency == 0: + wandb.log(logs, commit=True) + + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Logs training metrics.""" + if logs is not None: + self._on_batch_end(batch, logs) + + def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Logs validation metrics.""" + if logs is not None: + self._on_batch_end(batch, logs) + + def on_epoch_end(self, epoch: int, logs: Dict) -> None: + """Logs at epoch end.""" + wandb.log(logs, commit=True) + + +class WandbImageLogger(Callback): + """Custom W&B callback for image logging.""" + + def __init__( + self, + example_indices: Optional[List] = None, + num_examples: int = 4, + transfroms: Optional[Callable] = None, + ) -> None: + """Initializes the WandbImageLogger with the model to train. + + Args: + example_indices (Optional[List]): Indices for validation images. Defaults to None. + num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4. + transfroms (Optional[Callable]): Transforms to use on the validation images, e.g. transpose. Defaults to + None. 
+ + """ + + super().__init__() + self.example_indices = example_indices + self.num_examples = num_examples + self.transfroms = transfroms + if self.transfroms is None: + self.transforms = Compose([Transpose()]) + + def set_model(self, model: Type[Model]) -> None: + """Sets the model and extracts validation images from the dataset.""" + self.model = model + data_loader = self.model.data_loaders["val"] + if self.example_indices is None: + self.example_indices = np.random.randint( + 0, len(data_loader.dataset.data), self.num_examples + ) + self.val_images = data_loader.dataset.data[self.example_indices] + self.val_targets = data_loader.dataset.targets[self.example_indices].numpy() + + def on_epoch_end(self, epoch: int, logs: Dict) -> None: + """Get network predictions on validation images.""" + images = [] + for i, image in enumerate(self.val_images): + image = self.transforms(image) + pred, conf = self.model.predict_on_image(image) + ground_truth = self.model.mapper(int(self.val_targets[i])) + caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}" + images.append(wandb.Image(image, caption=caption)) + + wandb.log({"examples": images}, commit=False) diff --git a/src/training/trainer/population_based_training/__init__.py b/src/training/trainer/population_based_training/__init__.py new file mode 100644 index 0000000..868d739 --- /dev/null +++ b/src/training/trainer/population_based_training/__init__.py @@ -0,0 +1 @@ +"""TBC.""" diff --git a/src/training/trainer/population_based_training/population_based_training.py b/src/training/trainer/population_based_training/population_based_training.py new file mode 100644 index 0000000..868d739 --- /dev/null +++ b/src/training/trainer/population_based_training/population_based_training.py @@ -0,0 +1 @@ +"""TBC.""" diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py new file mode 100644 index 0000000..a75ae8f --- /dev/null +++ b/src/training/trainer/train.py @@ -0,0 +1,216 @@ +"""Training script for PyTorch models.""" + +from pathlib import Path +import time +from typing import Dict, List, Optional, Tuple, Type + +from loguru import logger +import numpy as np +import torch +from torch import Tensor +from training.trainer.callbacks import Callback, CallbackList +from training.trainer.util import RunningAverage +import wandb + +from text_recognizer.models import Model + + +torch.backends.cudnn.benchmark = True +np.random.seed(4711) +torch.manual_seed(4711) +torch.cuda.manual_seed(4711) + + +class Trainer: + """Trainer for training PyTorch models.""" + + def __init__( + self, + model: Type[Model], + model_dir: Path, + train_args: Dict, + callbacks: CallbackList, + checkpoint_path: Optional[Path] = None, + ) -> None: + """Initialization of the Trainer. + + Args: + model (Type[Model]): A model object. + model_dir (Path): Path to the model directory. + train_args (Dict): The training arguments. + callbacks (CallbackList): List of callbacks to be called. + checkpoint_path (Optional[Path]): The path to a previously trained model. Defaults to None. + + """ + self.model = model + self.model_dir = model_dir + self.checkpoint_path = checkpoint_path + self.start_epoch = 1 + self.epochs = train_args["epochs"] + self.callbacks = callbacks + + if self.checkpoint_path is not None: + self.start_epoch = self.model.load_checkpoint(self.checkpoint_path) + + # Parse the name of the experiment. 
+ experiment_dir = str(self.model_dir.parents[1]).split("/") + self.experiment_name = experiment_dir[-2] + "/" + experiment_dir[-1] + + def training_step( + self, + batch: int, + samples: Tuple[Tensor, Tensor], + loss_avg: Type[RunningAverage], + ) -> Dict: + """Performs the training step.""" + # Pass the tensor to the device for computation. + data, targets = samples + data, targets = ( + data.to(self.model.device), + targets.to(self.model.device), + ) + + # Forward pass. + # Get the network prediction. + output = self.model.network(data) + + # Compute the loss. + loss = self.model.criterion(output, targets) + + # Backward pass. + # Clear the previous gradients. + self.model.optimizer.zero_grad() + + # Compute the gradients. + loss.backward() + + # Perform updates using calculated gradients. + self.model.optimizer.step() + + # Compute metrics. + loss_avg.update(loss.item()) + output = output.data.cpu() + targets = targets.data.cpu() + metrics = { + metric: self.model.metrics[metric](output, targets) + for metric in self.model.metrics + } + metrics["loss"] = loss_avg() + return metrics + + def train(self) -> None: + """Runs the training loop for one epoch.""" + # Set model to traning mode. + self.model.train() + + # Running average for the loss. + loss_avg = RunningAverage() + + data_loader = self.model.data_loaders["train"] + + for batch, samples in enumerate(data_loader): + self.callbacks.on_train_batch_begin(batch) + metrics = self.training_step(batch, samples, loss_avg) + self.callbacks.on_train_batch_end(batch, logs=metrics) + + @torch.no_grad() + def validation_step( + self, + batch: int, + samples: Tuple[Tensor, Tensor], + loss_avg: Type[RunningAverage], + ) -> Dict: + """Performs the validation step.""" + # Pass the tensor to the device for computation. + data, targets = samples + data, targets = ( + data.to(self.model.device), + targets.to(self.model.device), + ) + + # Forward pass. + # Get the network prediction. + output = self.model.network(data) + + # Compute the loss. + loss = self.model.criterion(output, targets) + + # Compute metrics. + loss_avg.update(loss.item()) + output = output.data.cpu() + targets = targets.data.cpu() + metrics = { + metric: self.model.metrics[metric](output, targets) + for metric in self.model.metrics + } + metrics["loss"] = loss.item() + + return metrics + + def _log_val_metric(self, metrics_mean: Dict, epoch: Optional[int] = None) -> None: + log_str = "Validation metrics " + (f"at epoch {epoch} - " if epoch else " - ") + logger.debug( + log_str + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items()) + ) + + def validate(self, epoch: Optional[int] = None) -> Dict: + """Runs the validation loop for one epoch.""" + # Set model to eval mode. + self.model.eval() + + # Running average for the loss. + data_loader = self.model.data_loaders["val"] + + # Running average for the loss. + loss_avg = RunningAverage() + + # Summary for the current eval loop. + summary = [] + + for batch, samples in enumerate(data_loader): + self.callbacks.on_validation_batch_begin(batch) + metrics = self.validation_step(batch, samples, loss_avg) + self.callbacks.on_validation_batch_end(batch, logs=metrics) + summary.append(metrics) + + # Compute mean of all metrics. 
+ metrics_mean = { + "val_" + metric: np.mean([x[metric] for x in summary]) + for metric in summary[0] + } + self._log_val_metric(metrics_mean, epoch) + + return metrics_mean + + def fit(self) -> None: + """Runs the training and evaluation loop.""" + + logger.debug(f"Running an experiment called {self.experiment_name}.") + + # Set start time. + t_start = time.time() + + self.callbacks.on_fit_begin() + + # Run the training loop. + for epoch in range(self.start_epoch, self.epochs + 1): + self.callbacks.on_epoch_begin(epoch) + + # Perform one training pass over the training set. + self.train() + + # Evaluate the model on the validation set. + val_metrics = self.validate(epoch) + + self.callbacks.on_epoch_end(epoch, logs=val_metrics) + + if self.model.stop_training: + break + + # Calculate the total training time. + t_end = time.time() + t_training = t_end - t_start + + self.callbacks.on_fit_end() + + logger.info(f"Training took {t_training:.2f} s.") diff --git a/src/training/trainer/util.py b/src/training/trainer/util.py new file mode 100644 index 0000000..132b2dc --- /dev/null +++ b/src/training/trainer/util.py @@ -0,0 +1,19 @@ +"""Utility functions for training neural networks.""" + + +class RunningAverage: + """Maintains a running average.""" + + def __init__(self) -> None: + """Initializes the parameters.""" + self.steps = 0 + self.total = 0 + + def update(self, val: float) -> None: + """Updates the parameters.""" + self.total += val + self.steps += 1 + + def __call__(self) -> float: + """Computes the running average.""" + return self.total / float(self.steps) diff --git a/src/training/util.py b/src/training/util.py deleted file mode 100644 index 132b2dc..0000000 --- a/src/training/util.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Utility functions for training neural networks.""" - - -class RunningAverage: - """Maintains a running average.""" - - def __init__(self) -> None: - """Initializes the parameters.""" - self.steps = 0 - self.total = 0 - - def update(self, val: float) -> None: - """Updates the parameters.""" - self.total += val - self.steps += 1 - - def __call__(self) -> float: - """Computes the running average.""" - return self.total / float(self.steps) -- cgit v1.2.3-70-g09d2
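Editorial usage note: the sketch below shows how the relocated trainer package introduced in this patch is presumably driven, using only the constructor signatures visible above (CallbackList, Trainer, and the callback classes) together with the values from sample_experiment.yml. The model instance and the model_dir path are illustrative placeholders; in the repository they are built by run_experiment.py from the experiment config rather than hard-coded like this.

    from pathlib import Path

    from training.trainer import Trainer
    from training.trainer.callbacks import (
        CallbackList,
        Checkpoint,
        EarlyStopping,
        OneCycleLR,
        ProgressBar,
    )

    # Assumption: a fully configured text_recognizer Model (network, criterion,
    # optimizer, lr_scheduler and data loaders attached), e.g. the CharacterModel
    # that run_experiment.py constructs from the YAML experiment config.
    model = ...

    callbacks = CallbackList(
        model,
        [
            Checkpoint(monitor="val_accuracy"),
            ProgressBar(epochs=32, log_batch_frequency=100),
            EarlyStopping(monitor="val_loss", min_delta=0.0, patience=3),
            OneCycleLR(),
        ],
    )

    trainer = Trainer(
        model=model,
        model_dir=Path("experiments/CharacterModel/example_run/model"),  # illustrative path
        train_args={"batch_size": 256, "epochs": 32},
        callbacks=callbacks,
        checkpoint_path=None,
    )
    trainer.fit()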