summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.flake82
-rw-r--r--README.md2
-rw-r--r--poetry.lock384
-rw-r--r--pyproject.toml5
-rw-r--r--src/notebooks/00-testing-stuff-out.ipynb639
-rw-r--r--src/notebooks/01-look-at-emnist.ipynb25
-rw-r--r--src/text_recognizer/datasets/emnist_dataset.py43
-rw-r--r--src/text_recognizer/datasets/emnist_lines_dataset.py9
-rw-r--r--src/text_recognizer/models/base.py45
-rw-r--r--src/text_recognizer/models/character_model.py8
-rw-r--r--src/text_recognizer/networks/__init__.py3
-rw-r--r--src/text_recognizer/networks/lenet.py17
-rw-r--r--src/text_recognizer/networks/misc.py20
-rw-r--r--src/text_recognizer/networks/mlp.py18
-rw-r--r--src/text_recognizer/networks/residual_network.py314
-rw-r--r--src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.ptbin14485310 -> 14485362 bytes
-rw-r--r--src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.ptbin1704174 -> 11625484 bytes
-rw-r--r--src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.ptbin0 -> 28654593 bytes
-rw-r--r--src/training/experiments/sample_experiment.yml37
-rw-r--r--src/training/prepare_experiments.py6
-rw-r--r--src/training/run_experiment.py19
-rw-r--r--src/training/trainer/__init__.py2
-rw-r--r--src/training/trainer/callbacks/__init__.py (renamed from src/training/callbacks/__init__.py)2
-rw-r--r--src/training/trainer/callbacks/base.py (renamed from src/training/callbacks/base.py)50
-rw-r--r--src/training/trainer/callbacks/early_stopping.py (renamed from src/training/callbacks/early_stopping.py)5
-rw-r--r--src/training/trainer/callbacks/lr_schedulers.py (renamed from src/training/callbacks/lr_schedulers.py)12
-rw-r--r--src/training/trainer/callbacks/progress_bar.py61
-rw-r--r--src/training/trainer/callbacks/wandb_callbacks.py (renamed from src/training/callbacks/wandb_callbacks.py)8
-rw-r--r--src/training/trainer/population_based_training/__init__.py (renamed from src/training/population_based_training/__init__.py)0
-rw-r--r--src/training/trainer/population_based_training/population_based_training.py (renamed from src/training/population_based_training/population_based_training.py)0
-rw-r--r--src/training/trainer/train.py (renamed from src/training/train.py)87
-rw-r--r--src/training/trainer/util.py (renamed from src/training/util.py)0
32 files changed, 1452 insertions, 371 deletions
diff --git a/.flake8 b/.flake8
index a27b644..c64d7ca 100644
--- a/.flake8
+++ b/.flake8
@@ -1,6 +1,6 @@
[flake8]
select = ANN,B,B9,BLK,C,D,DAR,E,F,I,S,W
-ignore = E203,E501,W503,ANN101,F401,D202,S404
+ignore = E203,E501,W503,ANN101,ANN002,ANN003,F401,D202,S404,D107
max-line-length = 120
max-complexity = 10
application-import-names = text_recognizer,tests
diff --git a/README.md b/README.md
index 328cfda..844f2e0 100644
--- a/README.md
+++ b/README.md
@@ -22,7 +22,7 @@ TBC
- [x] Implement wandb
- [x] Implement lr scheduler as a callback
- [x] Implement save checkpoint callback
- - [ ] Implement TQDM progress bar (Low priority)
+ - [x] Implement TQDM progress bar (Low priority)
- [ ] Check that dataset exists, otherwise download it form the web. Do this in run_experiment.py.
- [x] Create repr func for data loaders
- [ ] Be able to restart with lr scheduler (May skip this BS)
diff --git a/poetry.lock b/poetry.lock
index f174cf1..89851d9 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -24,6 +24,23 @@ python-versions = "*"
version = "0.1.0"
[[package]]
+category = "dev"
+description = "The secure Argon2 password hashing algorithm."
+name = "argon2-cffi"
+optional = false
+python-versions = "*"
+version = "20.1.0"
+
+[package.dependencies]
+cffi = ">=1.0.0"
+six = "*"
+
+[package.extras]
+dev = ["coverage (>=5.0.2)", "hypothesis", "pytest", "sphinx", "wheel", "pre-commit"]
+docs = ["sphinx"]
+tests = ["coverage (>=5.0.2)", "hypothesis", "pytest"]
+
+[[package]]
category = "main"
description = "Atomic file writes."
marker = "sys_platform == \"win32\""
@@ -130,7 +147,7 @@ description = "When they're not builtins, they're boltons."
name = "boltons"
optional = false
python-versions = "*"
-version = "20.2.0"
+version = "20.2.1"
[[package]]
category = "main"
@@ -141,6 +158,17 @@ python-versions = "*"
version = "2020.6.20"
[[package]]
+category = "dev"
+description = "Foreign Function Interface for Python calling C code."
+name = "cffi"
+optional = false
+python-versions = "*"
+version = "1.14.2"
+
+[package.dependencies]
+pycparser = "*"
+
+[[package]]
category = "main"
description = "Universal encoding detector for Python 2 and 3"
name = "chardet"
@@ -166,7 +194,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
version = "0.4.3"
[[package]]
-category = "dev"
+category = "main"
description = "Updated configparser from Python 3.8 for Python 2.6+."
name = "configparser"
optional = false
@@ -254,7 +282,7 @@ typing-inspect = "*"
dev = ["coverage", "cuvner", "pytest", "tox", "versioneer", "black", "pylint", "pex", "bump2version", "docutils", "check-manifest", "readme-renderer", "pygments", "isort", "mypy", "pytest-sphinx", "towncrier", "marshmallow-union", "marshmallow-enum", "twine", "wheel"]
[[package]]
-category = "dev"
+category = "main"
description = "Python bindings for the docker credentials store API"
name = "docker-pycreds"
optional = false
@@ -418,7 +446,7 @@ python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
version = "0.18.2"
[[package]]
-category = "dev"
+category = "main"
description = "Git Object Database"
name = "gitdb"
optional = false
@@ -429,7 +457,7 @@ version = "4.0.5"
smmap = ">=3.0.1,<4"
[[package]]
-category = "dev"
+category = "main"
description = "Python Git Library"
name = "gitpython"
optional = false
@@ -457,7 +485,7 @@ six = ">=1.7"
test = ["mock (>=2.0.0)", "pytest (<5.0)"]
[[package]]
-category = "dev"
+category = "main"
description = "GraphQL client for Python"
name = "gql"
optional = false
@@ -471,7 +499,7 @@ requests = ">=2.12,<3"
six = ">=1.10.0"
[[package]]
-category = "dev"
+category = "main"
description = "GraphQL implementation for Python"
name = "graphql-core"
optional = false
@@ -566,8 +594,8 @@ category = "dev"
description = "IPython: Productive Interactive Computing"
name = "ipython"
optional = false
-python-versions = ">=3.6"
-version = "7.16.1"
+python-versions = ">=3.7"
+version = "7.17.0"
[package.dependencies]
appnope = "*"
@@ -796,9 +824,10 @@ description = "Python plotting package"
name = "matplotlib"
optional = false
python-versions = ">=3.6"
-version = "3.3.0"
+version = "3.3.1"
[package.dependencies]
+certifi = ">=2020.06.20"
cycler = ">=0.10"
kiwisolver = ">=1.0.1"
numpy = ">=1.15"
@@ -961,10 +990,11 @@ description = "A web-based notebook environment for interactive computing"
name = "notebook"
optional = false
python-versions = ">=3.5"
-version = "6.0.3"
+version = "6.1.3"
[package.dependencies]
Send2Trash = "*"
+argon2-cffi = "*"
ipykernel = "*"
ipython-genutils = "*"
jinja2 = "*"
@@ -974,12 +1004,13 @@ nbconvert = "*"
nbformat = "*"
prometheus-client = "*"
pyzmq = ">=17"
-terminado = ">=0.8.1"
+terminado = ">=0.8.3"
tornado = ">=5.0"
traitlets = ">=4.2.1"
[package.extras]
-test = ["nose", "coverage", "requests", "nose-warnings-filters", "nbval", "nose-exclude", "selenium", "pytest", "pytest-cov", "nose-exclude"]
+docs = ["sphinx", "nbsphinx", "sphinxcontrib-github-alt"]
+test = ["nose", "coverage", "requests", "nose-warnings-filters", "nbval", "nose-exclude", "selenium", "pytest", "pytest-cov", "requests-unixsocket"]
[[package]]
category = "main"
@@ -990,7 +1021,7 @@ python-versions = ">=3.6"
version = "1.19.1"
[[package]]
-category = "dev"
+category = "main"
description = "Python Bindings for the NVIDIA Management Library"
name = "nvidia-ml-py3"
optional = false
@@ -1002,11 +1033,11 @@ category = "main"
description = "Wrapper package for OpenCV python bindings."
name = "opencv-python"
optional = false
-python-versions = "*"
-version = "4.3.0.36"
+python-versions = ">=3.5"
+version = "4.4.0.42"
[package.dependencies]
-numpy = ">=1.11.1"
+numpy = ">=1.13.1"
[[package]]
category = "main"
@@ -1048,7 +1079,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
version = "0.8.0"
[[package]]
-category = "dev"
+category = "main"
description = "File system general utilities"
name = "pathtools"
optional = false
@@ -1119,7 +1150,7 @@ version = "0.8.0"
twisted = ["twisted"]
[[package]]
-category = "dev"
+category = "main"
description = "Promises/A+ implementation for Python"
name = "promise"
optional = false
@@ -1138,13 +1169,13 @@ description = "Library for building powerful interactive command lines in Python
name = "prompt-toolkit"
optional = false
python-versions = ">=3.6.1"
-version = "3.0.5"
+version = "3.0.6"
[package.dependencies]
wcwidth = "*"
[[package]]
-category = "dev"
+category = "main"
description = "Cross-platform lib for process and system monitoring in Python."
name = "psutil"
optional = false
@@ -1180,6 +1211,14 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
version = "2.6.0"
[[package]]
+category = "dev"
+description = "C parser in Python"
+name = "pycparser"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+version = "2.20"
+
+[[package]]
category = "main"
description = "Python docstring style checker"
name = "pydocstyle"
@@ -1257,7 +1296,7 @@ description = "Pytest plugin for measuring coverage."
name = "pytest-cov"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
-version = "2.10.0"
+version = "2.10.1"
[package.dependencies]
coverage = ">=4.4"
@@ -1298,7 +1337,7 @@ marker = "python_version == \"3.7\""
name = "pytype"
optional = false
python-versions = "<3.9,>=3.5"
-version = "2020.7.24"
+version = "2020.8.10"
[package.dependencies]
attrs = "*"
@@ -1335,7 +1374,7 @@ python-versions = "*"
version = "0.5.7"
[[package]]
-category = "dev"
+category = "main"
description = "YAML parser and emitter for Python"
name = "pyyaml"
optional = false
@@ -1348,7 +1387,7 @@ description = "Python bindings for 0MQ"
name = "pyzmq"
optional = false
python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*"
-version = "19.0.1"
+version = "19.0.2"
[[package]]
category = "dev"
@@ -1452,12 +1491,12 @@ python-versions = "*"
version = "1.5.0"
[[package]]
-category = "dev"
+category = "main"
description = "Python client for Sentry (https://sentry.io)"
name = "sentry-sdk"
optional = false
python-versions = "*"
-version = "0.16.2"
+version = "0.16.5"
[package.dependencies]
certifi = "*"
@@ -1479,7 +1518,7 @@ sqlalchemy = ["sqlalchemy (>=1.2)"]
tornado = ["tornado (>=5)"]
[[package]]
-category = "dev"
+category = "main"
description = "A generator library for concise, unambiguous and URL-safe UUIDs."
name = "shortuuid"
optional = false
@@ -1495,7 +1534,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
version = "1.15.0"
[[package]]
-category = "dev"
+category = "main"
description = "A pure Python implementation of a sliding window memory map manager"
name = "smmap"
optional = false
@@ -1516,7 +1555,7 @@ description = "Python documentation generator"
name = "sphinx"
optional = false
python-versions = ">=3.5"
-version = "3.1.2"
+version = "3.2.1"
[package.dependencies]
Jinja2 = ">=2.3"
@@ -1655,7 +1694,7 @@ python = "<3.8"
version = ">=1.7.0"
[[package]]
-category = "dev"
+category = "main"
description = "A backport of the subprocess module from Python 3 for use on 2.x."
name = "subprocess32"
optional = false
@@ -1699,8 +1738,8 @@ category = "main"
description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
name = "torch"
optional = false
-python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.6.0"
-version = "1.5.1"
+python-versions = ">=3.6.1"
+version = "1.6.0"
[package.dependencies]
future = "*"
@@ -1720,12 +1759,12 @@ description = "image and video datasets and models for torch deep learning"
name = "torchvision"
optional = false
python-versions = "*"
-version = "0.6.1"
+version = "0.7.0"
[package.dependencies]
numpy = "*"
pillow = ">=4.1.1"
-torch = "1.5.1"
+torch = "1.6.0"
[package.extras]
scipy = ["scipy"]
@@ -1744,7 +1783,7 @@ description = "Fast, Extensible Progress Meter"
name = "tqdm"
optional = false
python-versions = ">=2.6, !=3.0.*, !=3.1.*"
-version = "4.48.0"
+version = "4.48.2"
[package.extras]
dev = ["py-make (>=0.1.0)", "twine", "argopt", "pydoc-markdown"]
@@ -1819,12 +1858,12 @@ secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "pyOpenSSL (>=0
socks = ["PySocks (>=1.5.6,<1.5.7 || >1.5.7,<2.0)"]
[[package]]
-category = "dev"
+category = "main"
description = "A CLI and library for interacting with the Weights and Biases API."
name = "wandb"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
-version = "0.9.4"
+version = "0.9.5"
[package.dependencies]
Click = ">=7.0"
@@ -1849,7 +1888,7 @@ gcp = ["google-cloud-storage"]
kubeflow = ["kubernetes", "minio", "google-cloud-storage", "sh"]
[[package]]
-category = "dev"
+category = "main"
description = "Filesystem events monitoring"
name = "watchdog"
optional = false
@@ -1931,7 +1970,7 @@ docs = ["sphinx", "jaraco.packaging (>=3.2)", "rst.linker (>=1.9)"]
testing = ["jaraco.itertools", "func-timeout"]
[metadata]
-content-hash = "90078f3c0d16b4e2e67b4f5ad22219067d8936dfca527d130ec0a179d2c55cb0"
+content-hash = "8aea4adc15683cb4bdd650d199725418e7778687d311bcb6a02db94fa05719f7"
lock-version = "1.0"
python-versions = "^3.7"
@@ -1948,6 +1987,24 @@ appnope = [
{file = "appnope-0.1.0-py2.py3-none-any.whl", hash = "sha256:5b26757dc6f79a3b7dc9fab95359328d5747fcb2409d331ea66d0272b90ab2a0"},
{file = "appnope-0.1.0.tar.gz", hash = "sha256:8b995ffe925347a2138d7ac0fe77155e4311a0ea6d6da4f5128fe4b3cbe5ed71"},
]
+argon2-cffi = [
+ {file = "argon2-cffi-20.1.0.tar.gz", hash = "sha256:d8029b2d3e4b4cea770e9e5a0104dd8fa185c1724a0f01528ae4826a6d25f97d"},
+ {file = "argon2_cffi-20.1.0-cp27-cp27m-macosx_10_6_intel.whl", hash = "sha256:6ea92c980586931a816d61e4faf6c192b4abce89aa767ff6581e6ddc985ed003"},
+ {file = "argon2_cffi-20.1.0-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:05a8ac07c7026542377e38389638a8a1e9b78f1cd8439cd7493b39f08dd75fbf"},
+ {file = "argon2_cffi-20.1.0-cp27-cp27m-win32.whl", hash = "sha256:0bf066bc049332489bb2d75f69216416329d9dc65deee127152caeb16e5ce7d5"},
+ {file = "argon2_cffi-20.1.0-cp27-cp27m-win_amd64.whl", hash = "sha256:57358570592c46c420300ec94f2ff3b32cbccd10d38bdc12dc6979c4a8484fbc"},
+ {file = "argon2_cffi-20.1.0-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:7d455c802727710e9dfa69b74ccaab04568386ca17b0ad36350b622cd34606fe"},
+ {file = "argon2_cffi-20.1.0-cp35-abi3-manylinux1_x86_64.whl", hash = "sha256:b160416adc0f012fb1f12588a5e6954889510f82f698e23ed4f4fa57f12a0647"},
+ {file = "argon2_cffi-20.1.0-cp35-cp35m-win32.whl", hash = "sha256:9bee3212ba4f560af397b6d7146848c32a800652301843df06b9e8f68f0f7361"},
+ {file = "argon2_cffi-20.1.0-cp35-cp35m-win_amd64.whl", hash = "sha256:392c3c2ef91d12da510cfb6f9bae52512a4552573a9e27600bdb800e05905d2b"},
+ {file = "argon2_cffi-20.1.0-cp36-cp36m-win32.whl", hash = "sha256:ba7209b608945b889457f949cc04c8e762bed4fe3fec88ae9a6b7765ae82e496"},
+ {file = "argon2_cffi-20.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:da7f0445b71db6d3a72462e04f36544b0de871289b0bc8a7cc87c0f5ec7079fa"},
+ {file = "argon2_cffi-20.1.0-cp37-abi3-macosx_10_6_intel.whl", hash = "sha256:cc0e028b209a5483b6846053d5fd7165f460a1f14774d79e632e75e7ae64b82b"},
+ {file = "argon2_cffi-20.1.0-cp37-cp37m-win32.whl", hash = "sha256:18dee20e25e4be86680b178b35ccfc5d495ebd5792cd00781548d50880fee5c5"},
+ {file = "argon2_cffi-20.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:6678bb047373f52bcff02db8afab0d2a77d83bde61cfecea7c5c62e2335cb203"},
+ {file = "argon2_cffi-20.1.0-cp38-cp38-win32.whl", hash = "sha256:77e909cc756ef81d6abb60524d259d959bab384832f0c651ed7dcb6e5ccdbb78"},
+ {file = "argon2_cffi-20.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:9dfd5197852530294ecb5795c97a823839258dfd5eb9420233c7cfedec2058f2"},
+]
atomicwrites = [
{file = "atomicwrites-1.4.0-py2.py3-none-any.whl", hash = "sha256:6d1784dea7c0c8d4a5172b6c620f40b6e4cbfdf96d783691f2e1302a7b88e197"},
{file = "atomicwrites-1.4.0.tar.gz", hash = "sha256:ae70396ad1a434f9c7046fd2dd196fc04b12f9e91ffb859164193be8b6168a7a"},
@@ -1982,13 +2039,43 @@ blessings = [
{file = "blessings-1.7.tar.gz", hash = "sha256:98e5854d805f50a5b58ac2333411b0482516a8210f23f43308baeb58d77c157d"},
]
boltons = [
- {file = "boltons-20.2.0-py2.py3-none-any.whl", hash = "sha256:1567a4ab991a52be4e9a778c6c433e50237018b322759146b07732b71e57ef67"},
- {file = "boltons-20.2.0.tar.gz", hash = "sha256:d367506c0b32042bb1ee3bf7899f2dcc8492dceb42ce3727b89e174d85bffe6e"},
+ {file = "boltons-20.2.1-py2.py3-none-any.whl", hash = "sha256:3dd8a8e3c1886e7f7ba3422b50f55a66e1700161bf01b919d098e7d96dd2d9b6"},
+ {file = "boltons-20.2.1.tar.gz", hash = "sha256:dd362291a460cc1e0c2e91cc6a60da3036ced77099b623112e8f833e6734bdc5"},
]
certifi = [
{file = "certifi-2020.6.20-py2.py3-none-any.whl", hash = "sha256:8fc0819f1f30ba15bdb34cceffb9ef04d99f420f68eb75d901e9560b8749fc41"},
{file = "certifi-2020.6.20.tar.gz", hash = "sha256:5930595817496dd21bb8dc35dad090f1c2cd0adfaf21204bf6732ca5d8ee34d3"},
]
+cffi = [
+ {file = "cffi-1.14.2-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:da9d3c506f43e220336433dffe643fbfa40096d408cb9b7f2477892f369d5f82"},
+ {file = "cffi-1.14.2-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:23e44937d7695c27c66a54d793dd4b45889a81b35c0751ba91040fe825ec59c4"},
+ {file = "cffi-1.14.2-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:0da50dcbccd7cb7e6c741ab7912b2eff48e85af217d72b57f80ebc616257125e"},
+ {file = "cffi-1.14.2-cp27-cp27m-win32.whl", hash = "sha256:76ada88d62eb24de7051c5157a1a78fd853cca9b91c0713c2e973e4196271d0c"},
+ {file = "cffi-1.14.2-cp27-cp27m-win_amd64.whl", hash = "sha256:15a5f59a4808f82d8ec7364cbace851df591c2d43bc76bcbe5c4543a7ddd1bf1"},
+ {file = "cffi-1.14.2-cp27-cp27mu-manylinux1_i686.whl", hash = "sha256:e4082d832e36e7f9b2278bc774886ca8207346b99f278e54c9de4834f17232f7"},
+ {file = "cffi-1.14.2-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:57214fa5430399dffd54f4be37b56fe22cedb2b98862550d43cc085fb698dc2c"},
+ {file = "cffi-1.14.2-cp35-cp35m-macosx_10_9_x86_64.whl", hash = "sha256:6843db0343e12e3f52cc58430ad559d850a53684f5b352540ca3f1bc56df0731"},
+ {file = "cffi-1.14.2-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:577791f948d34d569acb2d1add5831731c59d5a0c50a6d9f629ae1cefd9ca4a0"},
+ {file = "cffi-1.14.2-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:8662aabfeab00cea149a3d1c2999b0731e70c6b5bac596d95d13f643e76d3d4e"},
+ {file = "cffi-1.14.2-cp35-cp35m-win32.whl", hash = "sha256:837398c2ec00228679513802e3744d1e8e3cb1204aa6ad408b6aff081e99a487"},
+ {file = "cffi-1.14.2-cp35-cp35m-win_amd64.whl", hash = "sha256:bf44a9a0141a082e89c90e8d785b212a872db793a0080c20f6ae6e2a0ebf82ad"},
+ {file = "cffi-1.14.2-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:29c4688ace466a365b85a51dcc5e3c853c1d283f293dfcc12f7a77e498f160d2"},
+ {file = "cffi-1.14.2-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:99cc66b33c418cd579c0f03b77b94263c305c389cb0c6972dac420f24b3bf123"},
+ {file = "cffi-1.14.2-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:65867d63f0fd1b500fa343d7798fa64e9e681b594e0a07dc934c13e76ee28fb1"},
+ {file = "cffi-1.14.2-cp36-cp36m-win32.whl", hash = "sha256:f5033952def24172e60493b68717792e3aebb387a8d186c43c020d9363ee7281"},
+ {file = "cffi-1.14.2-cp36-cp36m-win_amd64.whl", hash = "sha256:7057613efefd36cacabbdbcef010e0a9c20a88fc07eb3e616019ea1692fa5df4"},
+ {file = "cffi-1.14.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6539314d84c4d36f28d73adc1b45e9f4ee2a89cdc7e5d2b0a6dbacba31906798"},
+ {file = "cffi-1.14.2-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:672b539db20fef6b03d6f7a14b5825d57c98e4026401fce838849f8de73fe4d4"},
+ {file = "cffi-1.14.2-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:95e9094162fa712f18b4f60896e34b621df99147c2cee216cfa8f022294e8e9f"},
+ {file = "cffi-1.14.2-cp37-cp37m-win32.whl", hash = "sha256:b9aa9d8818c2e917fa2c105ad538e222a5bce59777133840b93134022a7ce650"},
+ {file = "cffi-1.14.2-cp37-cp37m-win_amd64.whl", hash = "sha256:e4b9b7af398c32e408c00eb4e0d33ced2f9121fd9fb978e6c1b57edd014a7d15"},
+ {file = "cffi-1.14.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e613514a82539fc48291d01933951a13ae93b6b444a88782480be32245ed4afa"},
+ {file = "cffi-1.14.2-cp38-cp38-manylinux1_i686.whl", hash = "sha256:9b219511d8b64d3fa14261963933be34028ea0e57455baf6781fe399c2c3206c"},
+ {file = "cffi-1.14.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:c0b48b98d79cf795b0916c57bebbc6d16bb43b9fc9b8c9f57f4cf05881904c75"},
+ {file = "cffi-1.14.2-cp38-cp38-win32.whl", hash = "sha256:15419020b0e812b40d96ec9d369b2bc8109cc3295eac6e013d3261343580cc7e"},
+ {file = "cffi-1.14.2-cp38-cp38-win_amd64.whl", hash = "sha256:12a453e03124069b6896107ee133ae3ab04c624bb10683e1ed1c1663df17c13c"},
+ {file = "cffi-1.14.2.tar.gz", hash = "sha256:ae8f34d50af2c2154035984b8b5fc5d9ed63f32fe615646ab435b05b132ca91b"},
+]
chardet = [
{file = "chardet-3.0.4-py2.py3-none-any.whl", hash = "sha256:fc323ffcaeaed0e0a02bf4d117757b98aed530d9ed4531e3e15460124c106691"},
{file = "chardet-3.0.4.tar.gz", hash = "sha256:84ab92ed1c4d4f16916e05906b6b75a6c0fb5db821cc65e70cbd64a3e2a5eaae"},
@@ -2186,8 +2273,8 @@ ipykernel = [
{file = "ipykernel-5.3.4.tar.gz", hash = "sha256:9b2652af1607986a1b231c62302d070bc0534f564c393a5d9d130db9abbbe89d"},
]
ipython = [
- {file = "ipython-7.16.1-py3-none-any.whl", hash = "sha256:2dbcc8c27ca7d3cfe4fcdff7f45b27f9a8d3edfa70ff8024a71c7a8eb5f09d64"},
- {file = "ipython-7.16.1.tar.gz", hash = "sha256:9f4fcb31d3b2c533333893b9172264e4821c1ac91839500f31bd43f2c59b3ccf"},
+ {file = "ipython-7.17.0-py3-none-any.whl", hash = "sha256:5a8f159ca8b22b9a0a1f2a28befe5ad2b703339afb58c2ffe0d7c8d7a3af5999"},
+ {file = "ipython-7.17.0.tar.gz", hash = "sha256:b70974aaa2674b05eb86a910c02ed09956a33f2dd6c71afc60f0b128a77e7f28"},
]
ipython-genutils = [
{file = "ipython_genutils-0.2.0-py2.py3-none-any.whl", hash = "sha256:72dd37233799e619666c9f639a9da83c34013a73e8bbc79a7a6348d93c61fab8"},
@@ -2292,27 +2379,24 @@ marshmallow = [
{file = "marshmallow-3.7.1.tar.gz", hash = "sha256:a2a5eefb4b75a3b43f05be1cca0b6686adf56af7465c3ca629e5ad8d1e1fe13d"},
]
matplotlib = [
- {file = "matplotlib-3.3.0-1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:b0786ac32983191fcd9cc0230b4ec2f8b3c25dee9beca46ca506c5d6cc5c593d"},
- {file = "matplotlib-3.3.0-1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:f9753c6292d5a1fe46828feb38d1de1820e3ea109a5fea0b6ea1dca6e9d0b220"},
- {file = "matplotlib-3.3.0-1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6aa7ea00ad7d898704ffed46e83efd7ec985beba57f507c957979f080678b9ea"},
- {file = "matplotlib-3.3.0-1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:19cf4db0272da286863a50406f6430101af129f288c421b1a7f33ddfc8d0180f"},
- {file = "matplotlib-3.3.0-1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ebb6168c9330309b1f3360d36c481d8cd621a490cf2a69c9d6625b2a76777c12"},
- {file = "matplotlib-3.3.0-1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:695b4165520bdfe381d15a6db03778babb265fee7affdc43e169a881f3f329bc"},
- {file = "matplotlib-3.3.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:9ccc651261b7044ffc3b1e2f9af17b1ef4c6a12fc080b5a7353ef0b53a50be28"},
- {file = "matplotlib-3.3.0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:e3868686f3023644523df486fc224b0af4349f3cdb933b0a71f261a574d7b65f"},
- {file = "matplotlib-3.3.0-cp36-cp36m-win32.whl", hash = "sha256:7c9adba58a67d23cc131c4189da56cb1d0f18a237c43188d831a44e4fc5df15a"},
- {file = "matplotlib-3.3.0-cp36-cp36m-win_amd64.whl", hash = "sha256:855bb281f3cc8e23ef66064a2beb229674fdf785638091fc82a172e3e84c2780"},
- {file = "matplotlib-3.3.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:7ce8f5364c74aac06abad84d8744d659bd86036e86c4ebf14c75ae4292597b46"},
- {file = "matplotlib-3.3.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:605e4d43b421524ad955a56535391e02866d07bce27c644e2c99e25fb59d63d1"},
- {file = "matplotlib-3.3.0-cp37-cp37m-win32.whl", hash = "sha256:cef05e9a2302f96d6f0666ee70ac7715cbc12e3802d8b8eb80bacd6ab81a0a24"},
- {file = "matplotlib-3.3.0-cp37-cp37m-win_amd64.whl", hash = "sha256:bf8d527a2eb9a5db1c9e5e405d1b1c4e66be983620c9ce80af6aae430d9a0c9c"},
- {file = "matplotlib-3.3.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c06ea133b44805d42f2507cb3503f6647b0c7918f1900b5063f5a8a69c63f6d2"},
- {file = "matplotlib-3.3.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:09b4096748178bcc764b81587b00ab720eac24965d2bf44ecbe09dcf4e8ed253"},
- {file = "matplotlib-3.3.0-cp38-cp38-win32.whl", hash = "sha256:c1f850908600efa60f81ad14eedbaf7cb17185a2c6d26586ae826ae5ce21f6e0"},
- {file = "matplotlib-3.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:2a9d10930406748b50f60c5fa74c399a1c1080aa6ce6e3fe5f38473b02f6f06d"},
- {file = "matplotlib-3.3.0-cp39-cp39-win32.whl", hash = "sha256:244a9088140a4c540e0a2db9c8ada5ad12520efded592a46e5bc43ff8f0fd0aa"},
- {file = "matplotlib-3.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:f74c39621b03cec7bc08498f140192ac26ca940ef20beac6dfad3714d2298b2a"},
- {file = "matplotlib-3.3.0.tar.gz", hash = "sha256:24e8db94948019d531ce0bcd637ac24b1c8f6744ac86d2aa0eb6dbaeb1386f82"},
+ {file = "matplotlib-3.3.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:282f8a077a1217f9f2ac178596f27c1ae94abbc6e7b785e1b8f25e83918e9199"},
+ {file = "matplotlib-3.3.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:83ae7261f4d5ab387be2caee29c4f499b1566f31c8ac97a0b8ab61afd9e3da92"},
+ {file = "matplotlib-3.3.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:1f9cf2b8500b833714a193cb24281153f5072d55b2e486009f1e81f0b7da3410"},
+ {file = "matplotlib-3.3.1-cp36-cp36m-win32.whl", hash = "sha256:0dc15e1ad84ec06bf0c315e6c4c2cced13a21ce4c2b4955bb75097064a4b1e92"},
+ {file = "matplotlib-3.3.1-cp36-cp36m-win_amd64.whl", hash = "sha256:ffbae66e2db70dc330cb3299525f97e1c0efdfc763e04e1a4e08f968c7ad21f0"},
+ {file = "matplotlib-3.3.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:88c6ab4a32a7447dad236b8371612aaba5c967d632ff11999e0478dd687f2c58"},
+ {file = "matplotlib-3.3.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:cc2d6b47c8fee89da982a312b54949ec0cd6a7976a8cafb5b62dea6c9883a14d"},
+ {file = "matplotlib-3.3.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:636c6330a7dcb18bac114dbeaff314fbbb0c11682f9a9601de69a50e331d18d7"},
+ {file = "matplotlib-3.3.1-cp37-cp37m-win32.whl", hash = "sha256:73a493e340064e8fe03207d9333b68baca30d9f0da543ae4af6b6b4f13f0fe05"},
+ {file = "matplotlib-3.3.1-cp37-cp37m-win_amd64.whl", hash = "sha256:6739b6cd9278d5cb337df0bd4400ad37bbd04c6dc7aa2c65e1e83a02bc4cc6fd"},
+ {file = "matplotlib-3.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:79f0c4730ad422ecb6bda814c9a9b375df36d6bd5a49eaa14e92e5f5e3e95ac3"},
+ {file = "matplotlib-3.3.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:e4d6d3afc454b4afc0d9d0ed52a8fa40a1b0d8f33c8e143e49a5833a7e32266b"},
+ {file = "matplotlib-3.3.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:96a5e667308dbf45670370d9dffb974e73b15bac0df0b5f3fb0b0ac7a572290e"},
+ {file = "matplotlib-3.3.1-cp38-cp38-win32.whl", hash = "sha256:bd8fceaa3494b531d43b6206966ba15705638137fc2dc5da5ee560cf9476867b"},
+ {file = "matplotlib-3.3.1-cp38-cp38-win_amd64.whl", hash = "sha256:1507c2a8e4662f6fa1d3ecc760782b158df8a3244ecc21c1d8dbb1cd0b3f872e"},
+ {file = "matplotlib-3.3.1-pp36-pypy36_pp73-macosx_10_9_x86_64.whl", hash = "sha256:2c3619ec2a5ead430a4536ebf8c77ea55d8ce36418919f831d35bc657ed5f27e"},
+ {file = "matplotlib-3.3.1-pp36-pypy36_pp73-manylinux2010_x86_64.whl", hash = "sha256:9703bc00a94a94c4e94b2ea0fbfbc9d2bb21159733134639fd931b6606c5c47e"},
+ {file = "matplotlib-3.3.1.tar.gz", hash = "sha256:87f53bcce90772f942c2db56736788b39332d552461a5cb13f05ff45c1680f0e"},
]
mccabe = [
{file = "mccabe-0.6.1-py2.py3-none-any.whl", hash = "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42"},
@@ -2375,8 +2459,8 @@ nltk = [
{file = "nltk-3.5.zip", hash = "sha256:845365449cd8c5f9731f7cb9f8bd6fd0767553b9d53af9eb1b3abf7700936b35"},
]
notebook = [
- {file = "notebook-6.0.3-py3-none-any.whl", hash = "sha256:3edc616c684214292994a3af05eaea4cc043f6b4247d830f3a2f209fa7639a80"},
- {file = "notebook-6.0.3.tar.gz", hash = "sha256:47a9092975c9e7965ada00b9a20f0cf637d001db60d241d479f53c0be117ad48"},
+ {file = "notebook-6.1.3-py3-none-any.whl", hash = "sha256:964cc40cff68e473f3778aef9266e867f7703cb4aebdfd250f334efe02f64c86"},
+ {file = "notebook-6.1.3.tar.gz", hash = "sha256:9990d51b9931a31e681635899aeb198b4c4b41586a9e87fbfaaed1a71d0a05b6"},
]
numpy = [
{file = "numpy-1.19.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:b1cca51512299841bf69add3b75361779962f9cee7d9ee3bb446d5982e925b69"},
@@ -2410,25 +2494,27 @@ nvidia-ml-py3 = [
{file = "nvidia-ml-py3-7.352.0.tar.gz", hash = "sha256:390f02919ee9d73fe63a98c73101061a6b37fa694a793abf56673320f1f51277"},
]
opencv-python = [
- {file = "opencv_python-4.3.0.36-cp35-cp35m-manylinux2014_i686.whl", hash = "sha256:c4f1e9d963c8f370284afa87fcf521cc8439a610a500bf8ede27fd64dd9050bd"},
- {file = "opencv_python-4.3.0.36-cp35-cp35m-manylinux2014_x86_64.whl", hash = "sha256:55e1d7a2d11c40ea5b53aabe5c4122038803c7d492505c8f93af077aa7fe2ce1"},
- {file = "opencv_python-4.3.0.36-cp35-cp35m-win32.whl", hash = "sha256:76ddc6daf8607eda1d866395dcf98526ef96f3e616d8c37ccc7629f9aaf6d4d4"},
- {file = "opencv_python-4.3.0.36-cp35-cp35m-win_amd64.whl", hash = "sha256:2fe704e35808cf6b17b793e89fd00e9ef7779f85f274666a4e092671aedd09c0"},
- {file = "opencv_python-4.3.0.36-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:2ec6502bfac01b27ac06daf7bc9f7a4f482a6a0d8e1b30e15c411d478454a19f"},
- {file = "opencv_python-4.3.0.36-cp36-cp36m-manylinux2014_i686.whl", hash = "sha256:f67c1d92ff96c6c106f786b7ef9b9ab448fa03ef28cb7bb6f0f7b857b65bc158"},
- {file = "opencv_python-4.3.0.36-cp36-cp36m-manylinux2014_x86_64.whl", hash = "sha256:c93b1198c85175a9fa9a9839c4da55c7ab9c5f57256f2e4211cd6c91d7d422e8"},
- {file = "opencv_python-4.3.0.36-cp36-cp36m-win32.whl", hash = "sha256:1bf486680a16d739f7852a62865b72eb7692df584694815774ba97b471b8bc3f"},
- {file = "opencv_python-4.3.0.36-cp36-cp36m-win_amd64.whl", hash = "sha256:f6fa2834d85c78865ca6e3de563916086cb8c83c3f2ef80924fcd07005f05df9"},
- {file = "opencv_python-4.3.0.36-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:fa1a6d149a1a5e0bc54c737a59fe38d75384a092ae5e35f9b876fbb621f755c6"},
- {file = "opencv_python-4.3.0.36-cp37-cp37m-manylinux2014_i686.whl", hash = "sha256:4b93b5f8df187e4dba9fb25c46fa8cf342c257de144f7c86d75c06416566a199"},
- {file = "opencv_python-4.3.0.36-cp37-cp37m-manylinux2014_x86_64.whl", hash = "sha256:fd457deedcf153dd6805a2b4d891ac2a0969566d3755fbf48a3ffb53978c9ed1"},
- {file = "opencv_python-4.3.0.36-cp37-cp37m-win32.whl", hash = "sha256:ef4ac758a4e2caee80ef9c86b83a279d6f132c9e7ae77957cf74013928814e05"},
- {file = "opencv_python-4.3.0.36-cp37-cp37m-win_amd64.whl", hash = "sha256:d765c44827778cbe6bc8f272cd61514e8509b93fd24dd3324cd4abddf2026b11"},
- {file = "opencv_python-4.3.0.36-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:677f61332436e22f83a1e4e6f6863a760734fbc8029ba6a8ef0af4554cde6f93"},
- {file = "opencv_python-4.3.0.36-cp38-cp38-manylinux2014_i686.whl", hash = "sha256:eb709245e56f6693d297f8818ff8e6c017fa80fdb5a923c64be623a678c7150e"},
- {file = "opencv_python-4.3.0.36-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:210ab40c8c9dadc7dc9ed7beebe2e0a2415a744f8d6857762a80c8e0fcc477c8"},
- {file = "opencv_python-4.3.0.36-cp38-cp38-win32.whl", hash = "sha256:156e2954d5b38b676e8a24d66703cf15f252e24ec49db7e842a8b5eed46074ba"},
- {file = "opencv_python-4.3.0.36-cp38-cp38-win_amd64.whl", hash = "sha256:1ea08f22246ccd33174d59edfa3f13930bf2c28096568242090bd9d8770fb8a8"},
+ {file = "opencv-python-4.4.0.42.tar.gz", hash = "sha256:0039506845d7076e6871c0075227881a84de69799d70ed37c8704d203b740911"},
+ {file = "opencv_python-4.4.0.42-cp35-cp35m-macosx_10_13_x86_64.whl", hash = "sha256:608dae0444065669fc26fa6bf1653072e40735b33dfa514c74a6165563a99e97"},
+ {file = "opencv_python-4.4.0.42-cp35-cp35m-manylinux2014_i686.whl", hash = "sha256:fae421571a7709ae0baa9bfd08177165bc1d56d7c79c806d12627d58a6faf2d1"},
+ {file = "opencv_python-4.4.0.42-cp35-cp35m-manylinux2014_x86_64.whl", hash = "sha256:a35b3a3540623090ba5fdad7ed97d0d75ca80ee55f5d7c1cecddda723665c0f8"},
+ {file = "opencv_python-4.4.0.42-cp35-cp35m-win32.whl", hash = "sha256:177f14625ea164f38b5b6f5c2b316f8ff8163e996cc0432de90f475956a9069a"},
+ {file = "opencv_python-4.4.0.42-cp35-cp35m-win_amd64.whl", hash = "sha256:093c1bfa6da24a9d4dde2d54a22b9acfb46f5cb2c50d7387356cf897f0db0ab9"},
+ {file = "opencv_python-4.4.0.42-cp36-cp36m-macosx_10_13_x86_64.whl", hash = "sha256:f5b82cd49b560e004608ca53ce625e5167b41f0fdc610758d6989083e26b5a03"},
+ {file = "opencv_python-4.4.0.42-cp36-cp36m-manylinux2014_i686.whl", hash = "sha256:bcb24c4f82fa79f049db4bfd0da1d18a315da66a55aa3d4cde81d1ec18f0a7ff"},
+ {file = "opencv_python-4.4.0.42-cp36-cp36m-manylinux2014_x86_64.whl", hash = "sha256:cb00bbd41268f5fa0fa327ca30f7621a8ece983e0d8ae472e2ffe7ab1617606f"},
+ {file = "opencv_python-4.4.0.42-cp36-cp36m-win32.whl", hash = "sha256:78a0796ec15d1b41f5a87c41f339356eb04858749c8845936be532cb3436f898"},
+ {file = "opencv_python-4.4.0.42-cp36-cp36m-win_amd64.whl", hash = "sha256:34d0d2c9a80c02d55f83a67c29fc4145a9dcf1fe3ddef0535d0b0d9c7b89b8d2"},
+ {file = "opencv_python-4.4.0.42-cp37-cp37m-macosx_10_13_x86_64.whl", hash = "sha256:505bd984aae24c489910bbd168e515580d62bc1dbdd5ee36f2c2d42803c4b795"},
+ {file = "opencv_python-4.4.0.42-cp37-cp37m-manylinux2014_i686.whl", hash = "sha256:17663f0469b2944b7d4051d4b1c425235d153777f17310c6990370bbb4d12695"},
+ {file = "opencv_python-4.4.0.42-cp37-cp37m-manylinux2014_x86_64.whl", hash = "sha256:d19cbbcdc05caf7b41e28898f05c076c94b07647b4556c8327663a40acd4e3bd"},
+ {file = "opencv_python-4.4.0.42-cp37-cp37m-win32.whl", hash = "sha256:ccd92a126d253c7bd65b36184fe097a0eea77da4d72d427e1630633bc586233e"},
+ {file = "opencv_python-4.4.0.42-cp37-cp37m-win_amd64.whl", hash = "sha256:80a51a797f71ee4a401d281749bb096370007202204bbcd1ecfc9ead58bd3b0b"},
+ {file = "opencv_python-4.4.0.42-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:02f7e31c710a7c82229fc4ad98e7e4cf265d19ab52b4451cbe7e33a840fe6595"},
+ {file = "opencv_python-4.4.0.42-cp38-cp38-manylinux2014_i686.whl", hash = "sha256:b3ae62990faebefbc3cbc5430f7b6de57bafdcf297134113a9c6d6ccfce4438f"},
+ {file = "opencv_python-4.4.0.42-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:fec63240ea3179a2b4176a3256a99682129d75450a15bf2807904600ec64b45a"},
+ {file = "opencv_python-4.4.0.42-cp38-cp38-win32.whl", hash = "sha256:324a2c680caae9edbd843a355a2e03792cbd23faf6c24c20dd594fa9aac80765"},
+ {file = "opencv_python-4.4.0.42-cp38-cp38-win_amd64.whl", hash = "sha256:a6e1d065a45ec1bf466f47bdf767e0505b244c9470140cf8bab1dd8835f0d3ee"},
]
packaging = [
{file = "packaging-20.4-py2.py3-none-any.whl", hash = "sha256:998416ba6962ae7fbd6596850b80e17859a5753ba17c32284f67bfff33784181"},
@@ -2500,8 +2586,8 @@ promise = [
{file = "promise-2.3.tar.gz", hash = "sha256:dfd18337c523ba4b6a58801c164c1904a9d4d1b1747c7d5dbf45b693a49d93d0"},
]
prompt-toolkit = [
- {file = "prompt_toolkit-3.0.5-py3-none-any.whl", hash = "sha256:df7e9e63aea609b1da3a65641ceaf5bc7d05e0a04de5bd45d05dbeffbabf9e04"},
- {file = "prompt_toolkit-3.0.5.tar.gz", hash = "sha256:563d1a4140b63ff9dd587bda9557cffb2fe73650205ab6f4383092fb882e7dc8"},
+ {file = "prompt_toolkit-3.0.6-py3-none-any.whl", hash = "sha256:683397077a64cd1f750b71c05afcfc6612a7300cb6932666531e5a54f38ea564"},
+ {file = "prompt_toolkit-3.0.6.tar.gz", hash = "sha256:7630ab85a23302839a0f26b31cc24f518e6155dea1ed395ea61b42c45941b6a6"},
]
psutil = [
{file = "psutil-5.7.2-cp27-none-win32.whl", hash = "sha256:f2018461733b23f308c298653c8903d32aaad7873d25e1d228765e91ae42c3f2"},
@@ -2528,6 +2614,10 @@ pycodestyle = [
{file = "pycodestyle-2.6.0-py2.py3-none-any.whl", hash = "sha256:2295e7b2f6b5bd100585ebcb1f616591b652db8a741695b3d8f5d28bdc934367"},
{file = "pycodestyle-2.6.0.tar.gz", hash = "sha256:c58a7d2815e0e8d7972bf1803331fb0152f867bd89adf8a01dfd55085434192e"},
]
+pycparser = [
+ {file = "pycparser-2.20-py2.py3-none-any.whl", hash = "sha256:7582ad22678f0fcd81102833f60ef8d0e57288b6b5fb00323d101be910e35705"},
+ {file = "pycparser-2.20.tar.gz", hash = "sha256:2d475327684562c3a96cc71adf7dc8c4f0565175cf86b6d7a404ff4c771f15f0"},
+]
pydocstyle = [
{file = "pydocstyle-5.0.2-py3-none-any.whl", hash = "sha256:da7831660b7355307b32778c4a0dbfb137d89254ef31a2b2978f50fc0b4d7586"},
{file = "pydocstyle-5.0.2.tar.gz", hash = "sha256:f4f5d210610c2d153fae39093d44224c17429e2ad7da12a8b419aba5c2f614b5"},
@@ -2552,8 +2642,8 @@ pytest = [
{file = "pytest-5.4.3.tar.gz", hash = "sha256:7979331bfcba207414f5e1263b5a0f8f521d0f457318836a7355531ed1a4c7d8"},
]
pytest-cov = [
- {file = "pytest-cov-2.10.0.tar.gz", hash = "sha256:1a629dc9f48e53512fcbfda6b07de490c374b0c83c55ff7a1720b3fccff0ac87"},
- {file = "pytest_cov-2.10.0-py2.py3-none-any.whl", hash = "sha256:6e6d18092dce6fad667cd7020deed816f858ad3b49d5b5e2b1cc1c97a4dba65c"},
+ {file = "pytest-cov-2.10.1.tar.gz", hash = "sha256:47bd0ce14056fdd79f93e1713f88fad7bdcc583dcd7783da86ef2f085a0bb88e"},
+ {file = "pytest_cov-2.10.1-py2.py3-none-any.whl", hash = "sha256:45ec2d5182f89a81fc3eb29e3d1ed3113b9e9a873bcddb2a71faaab066110191"},
]
pytest-mock = [
{file = "pytest-mock-3.2.0.tar.gz", hash = "sha256:7122d55505d5ed5a6f3df940ad174b3f606ecae5e9bc379569cdcbd4cd9d2b83"},
@@ -2564,13 +2654,13 @@ python-dateutil = [
{file = "python_dateutil-2.8.1-py2.py3-none-any.whl", hash = "sha256:75bb3f31ea686f1197762692a9ee6a7550b59fc6ca3a1f4b5d7e32fb98e2da2a"},
]
pytype = [
- {file = "pytype-2020.7.24-cp35-cp35m-macosx_10_14_x86_64.whl", hash = "sha256:a977290a6c2e12c3f69987bc21a5bf16f55372535c16b59e2480c07f4576ac61"},
- {file = "pytype-2020.7.24-cp35-cp35m-manylinux2014_x86_64.whl", hash = "sha256:d1bc9066c922f1179e7ba2da320094f88942cf36da91654f5e9dc2968b9ba180"},
- {file = "pytype-2020.7.24-cp36-cp36m-macosx_10_14_x86_64.whl", hash = "sha256:d2193d30730f3d6e9c1dcb284cfaa1287ec642265cf2a0e972bcb0514c3b80a9"},
- {file = "pytype-2020.7.24-cp36-cp36m-manylinux2014_x86_64.whl", hash = "sha256:17172f17daddcc67c281df11ffbc2f0847b1e71d799180ace79fc398b1b16315"},
- {file = "pytype-2020.7.24-cp37-cp37m-macosx_10_14_x86_64.whl", hash = "sha256:3ca6eb0d966d30523f8249cfc37885cfb3faa169587576592e9c96dbba568a5e"},
- {file = "pytype-2020.7.24-cp37-cp37m-manylinux2014_x86_64.whl", hash = "sha256:49d3d6635241c7924ea671a61b47f1ec35d5e64df1f884d0c29279697cf1ca2a"},
- {file = "pytype-2020.7.24.tar.gz", hash = "sha256:9ebc2b1351621323ce8644f5d20fac190ceb8ebb911aa9f880a1a93c6844803e"},
+ {file = "pytype-2020.8.10-cp35-cp35m-macosx_10_14_x86_64.whl", hash = "sha256:da1977a1aa74fbd237e889c1d29421d490e0be9a91a22efd96fbca2570ef9165"},
+ {file = "pytype-2020.8.10-cp35-cp35m-manylinux2014_x86_64.whl", hash = "sha256:e0909b99aff8eff0ece91fd64e00b935f0e4fecb51359d83d742b27db160dd00"},
+ {file = "pytype-2020.8.10-cp36-cp36m-macosx_10_14_x86_64.whl", hash = "sha256:6b0cd56b411738eb607a299437ac405cc94208875e97ba56332105103676a903"},
+ {file = "pytype-2020.8.10-cp36-cp36m-manylinux2014_x86_64.whl", hash = "sha256:08662a6d426a4ef246ba36a807d526734f437451a78f83a140d338e305bc877a"},
+ {file = "pytype-2020.8.10-cp37-cp37m-macosx_10_14_x86_64.whl", hash = "sha256:768c9ea0b08f40ce8e1ed8b9207862394d770fe3340ebebfd4210a82af530d67"},
+ {file = "pytype-2020.8.10-cp37-cp37m-manylinux2014_x86_64.whl", hash = "sha256:cd4399abd26a4a3498eca1ac4ad264cb407b17dfd826f9d5f14d6c06c78cf42a"},
+ {file = "pytype-2020.8.10.tar.gz", hash = "sha256:6385b6837a6db69c42eb477e8f7539c0b986ec6753eab4d811553d63d58a7785"},
]
pytz = [
{file = "pytz-2020.1-py2.py3-none-any.whl", hash = "sha256:a494d53b6d39c3c6e44c3bec237336e14305e4f29bbf800b599253057fbb79ed"},
@@ -2616,34 +2706,34 @@ pyyaml = [
{file = "PyYAML-5.3.1.tar.gz", hash = "sha256:b8eac752c5e14d3eca0e6dd9199cd627518cb5ec06add0de9d32baeee6fe645d"},
]
pyzmq = [
- {file = "pyzmq-19.0.1-cp27-cp27m-macosx_10_9_intel.whl", hash = "sha256:58688a2dfa044fad608a8e70ba8d019d0b872ec2acd75b7b5e37da8905605891"},
- {file = "pyzmq-19.0.1-cp27-cp27m-win32.whl", hash = "sha256:87c78f6936e2654397ca2979c1d323ee4a889eef536cc77a938c6b5be33351a7"},
- {file = "pyzmq-19.0.1-cp27-cp27m-win_amd64.whl", hash = "sha256:97b6255ae77328d0e80593681826a0479cb7bac0ba8251b4dd882f5145a2293a"},
- {file = "pyzmq-19.0.1-cp27-cp27mu-manylinux1_i686.whl", hash = "sha256:15b4cb21118f4589c4db8be4ac12b21c8b4d0d42b3ee435d47f686c32fe2e91f"},
- {file = "pyzmq-19.0.1-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:931339ac2000d12fe212e64f98ce291e81a7ec6c73b125f17cf08415b753c087"},
- {file = "pyzmq-19.0.1-cp35-cp35m-macosx_10_9_intel.whl", hash = "sha256:2a88b8fabd9cc35bd59194a7723f3122166811ece8b74018147a4ed8489e6421"},
- {file = "pyzmq-19.0.1-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:bafd651b557dd81d89bd5f9c678872f3e7b7255c1c751b78d520df2caac80230"},
- {file = "pyzmq-19.0.1-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:8952f6ba6ae598e792703f3134af5a01af8f5c7cf07e9a148f05a12b02412cea"},
- {file = "pyzmq-19.0.1-cp35-cp35m-win32.whl", hash = "sha256:54aa24fd60c4262286fc64ca632f9e747c7cc3a3a1144827490e1dc9b8a3a960"},
- {file = "pyzmq-19.0.1-cp35-cp35m-win_amd64.whl", hash = "sha256:dcbc3f30c11c60d709c30a213dc56e88ac016fe76ac6768e64717bd976072566"},
- {file = "pyzmq-19.0.1-cp36-cp36m-macosx_10_9_intel.whl", hash = "sha256:6ca519309703e95d55965735a667809bbb65f52beda2fdb6312385d3e7a6d234"},
- {file = "pyzmq-19.0.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:4ee0bfd82077a3ff11c985369529b12853a4064320523f8e5079b630f9551448"},
- {file = "pyzmq-19.0.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:ba6f24431b569aec674ede49cad197cad59571c12deed6ad8e3c596da8288217"},
- {file = "pyzmq-19.0.1-cp36-cp36m-win32.whl", hash = "sha256:956775444d01331c7eb412c5fb9bb62130dfaac77e09f32764ea1865234e2ca9"},
- {file = "pyzmq-19.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:b08780e3a55215873b3b8e6e7ca8987f14c902a24b6ac081b344fd430d6ca7cd"},
- {file = "pyzmq-19.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:21f7d91f3536f480cb2c10d0756bfa717927090b7fb863e6323f766e5461ee1c"},
- {file = "pyzmq-19.0.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:bfff5ffff051f5aa47ba3b379d87bd051c3196b0c8a603e8b7ed68a6b4f217ec"},
- {file = "pyzmq-19.0.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:07fb8fe6826a229dada876956590135871de60dbc7de5a18c3bcce2ed1f03c98"},
- {file = "pyzmq-19.0.1-cp37-cp37m-win32.whl", hash = "sha256:342fb8a1dddc569bc361387782e8088071593e7eaf3e3ecf7d6bd4976edff112"},
- {file = "pyzmq-19.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:faee2604f279d31312bc455f3d024f160b6168b9c1dde22bf62d8c88a4deca8e"},
- {file = "pyzmq-19.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5b9d21fc56c8aacd2e6d14738021a9d64f3f69b30578a99325a728e38a349f85"},
- {file = "pyzmq-19.0.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:af0c02cf49f4f9eedf38edb4f3b6bb621d83026e7e5d76eb5526cc5333782fd6"},
- {file = "pyzmq-19.0.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:5f1f2eb22aab606f808163eb1d537ac9a0ba4283fbeb7a62eb48d9103cf015c2"},
- {file = "pyzmq-19.0.1-cp38-cp38-win32.whl", hash = "sha256:f9d7e742fb0196992477415bb34366c12e9bb9a0699b8b3f221ff93b213d7bec"},
- {file = "pyzmq-19.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:5b99c2ae8089ef50223c28bac57510c163bfdff158c9e90764f812b94e69a0e6"},
- {file = "pyzmq-19.0.1-pp27-pypy_73-macosx_10_9_x86_64.whl", hash = "sha256:cf5d689ba9513b9753959164cf500079383bc18859f58bf8ce06d8d4bef2b054"},
- {file = "pyzmq-19.0.1-pp36-pypy36_pp73-macosx_10_9_x86_64.whl", hash = "sha256:aaa8b40b676576fd7806839a5de8e6d5d1b74981e6376d862af6c117af2a3c10"},
- {file = "pyzmq-19.0.1.tar.gz", hash = "sha256:13a5638ab24d628a6ade8f794195e1a1acd573496c3b85af2f1183603b7bf5e0"},
+ {file = "pyzmq-19.0.2-cp27-cp27m-macosx_10_9_intel.whl", hash = "sha256:59f1e54627483dcf61c663941d94c4af9bf4163aec334171686cdaee67974fe5"},
+ {file = "pyzmq-19.0.2-cp27-cp27m-win32.whl", hash = "sha256:c36ffe1e5aa35a1af6a96640d723d0d211c5f48841735c2aa8d034204e87eb87"},
+ {file = "pyzmq-19.0.2-cp27-cp27m-win_amd64.whl", hash = "sha256:0a422fc290d03958899743db091f8154958410fc76ce7ee0ceb66150f72c2c97"},
+ {file = "pyzmq-19.0.2-cp27-cp27mu-manylinux1_i686.whl", hash = "sha256:c20dd60b9428f532bc59f2ef6d3b1029a28fc790d408af82f871a7db03e722ff"},
+ {file = "pyzmq-19.0.2-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:d46fb17f5693244de83e434648b3dbb4f4b0fec88415d6cbab1c1452b6f2ae17"},
+ {file = "pyzmq-19.0.2-cp35-cp35m-macosx_10_9_intel.whl", hash = "sha256:f1a25a61495b6f7bb986accc5b597a3541d9bd3ef0016f50be16dbb32025b302"},
+ {file = "pyzmq-19.0.2-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:ab0d01148d13854de716786ca73701012e07dff4dfbbd68c4e06d8888743526e"},
+ {file = "pyzmq-19.0.2-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:720d2b6083498a9281eaee3f2927486e9fe02cd16d13a844f2e95217f243efea"},
+ {file = "pyzmq-19.0.2-cp35-cp35m-win32.whl", hash = "sha256:29d51279060d0a70f551663bc592418bcad7f4be4eea7b324f6dd81de05cb4c1"},
+ {file = "pyzmq-19.0.2-cp35-cp35m-win_amd64.whl", hash = "sha256:5120c64646e75f6db20cc16b9a94203926ead5d633de9feba4f137004241221d"},
+ {file = "pyzmq-19.0.2-cp36-cp36m-macosx_10_9_intel.whl", hash = "sha256:8a6ada5a3f719bf46a04ba38595073df8d6b067316c011180102ba2a1925f5b5"},
+ {file = "pyzmq-19.0.2-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:fa411b1d8f371d3a49d31b0789eb6da2537dadbb2aef74a43aa99a78195c3f76"},
+ {file = "pyzmq-19.0.2-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:00dca814469436455399660247d74045172955459c0bd49b54a540ce4d652185"},
+ {file = "pyzmq-19.0.2-cp36-cp36m-win32.whl", hash = "sha256:046b92e860914e39612e84fa760fc3f16054d268c11e0e25dcb011fb1bc6a075"},
+ {file = "pyzmq-19.0.2-cp36-cp36m-win_amd64.whl", hash = "sha256:99cc0e339a731c6a34109e5c4072aaa06d8e32c0b93dc2c2d90345dd45fa196c"},
+ {file = "pyzmq-19.0.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e36f12f503511d72d9bdfae11cadbadca22ff632ff67c1b5459f69756a029c19"},
+ {file = "pyzmq-19.0.2-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:c40fbb2b9933369e994b837ee72193d6a4c35dfb9a7c573257ef7ff28961272c"},
+ {file = "pyzmq-19.0.2-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:5d9fc809aa8d636e757e4ced2302569d6e60e9b9c26114a83f0d9d6519c40493"},
+ {file = "pyzmq-19.0.2-cp37-cp37m-win32.whl", hash = "sha256:3fa6debf4bf9412e59353defad1f8035a1e68b66095a94ead8f7a61ae90b2675"},
+ {file = "pyzmq-19.0.2-cp37-cp37m-win_amd64.whl", hash = "sha256:73483a2caaa0264ac717af33d6fb3f143d8379e60a422730ee8d010526ce1913"},
+ {file = "pyzmq-19.0.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:36ab114021c0cab1a423fe6689355e8f813979f2c750968833b318c1fa10a0fd"},
+ {file = "pyzmq-19.0.2-cp38-cp38-manylinux1_i686.whl", hash = "sha256:8b66b94fe6243d2d1d89bca336b2424399aac57932858b9a30309803ffc28112"},
+ {file = "pyzmq-19.0.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:654d3e06a4edc566b416c10293064732516cf8871a4522e0a2ba00cc2a2e600c"},
+ {file = "pyzmq-19.0.2-cp38-cp38-win32.whl", hash = "sha256:276ad604bffd70992a386a84bea34883e696a6b22e7378053e5d3227321d9702"},
+ {file = "pyzmq-19.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:09d24a80ccb8cbda1af6ed8eb26b005b6743e58e9290566d2a6841f4e31fa8e0"},
+ {file = "pyzmq-19.0.2-pp27-pypy_73-macosx_10_9_x86_64.whl", hash = "sha256:c1a31cd42905b405530e92bdb70a8a56f048c8a371728b8acf9d746ecd4482c0"},
+ {file = "pyzmq-19.0.2-pp36-pypy36_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a7e7f930039ee0c4c26e4dfee015f20bd6919cd8b97c9cd7afbde2923a5167b6"},
+ {file = "pyzmq-19.0.2.tar.gz", hash = "sha256:296540a065c8c21b26d63e3cea2d1d57902373b16e4256afe46422691903a438"},
]
qtconsole = [
{file = "qtconsole-4.7.5-py2.py3-none-any.whl", hash = "sha256:4f43d0b049eacb7d723772847f0c465feccce0ccb398871a6e146001a22bad23"},
@@ -2696,8 +2786,8 @@ send2trash = [
{file = "Send2Trash-1.5.0.tar.gz", hash = "sha256:60001cc07d707fe247c94f74ca6ac0d3255aabcb930529690897ca2a39db28b2"},
]
sentry-sdk = [
- {file = "sentry-sdk-0.16.2.tar.gz", hash = "sha256:38bb09d0277117f76507c8728d9a5156f09a47ac5175bb8072513859d19a593b"},
- {file = "sentry_sdk-0.16.2-py2.py3-none-any.whl", hash = "sha256:2de15b13836fa3522815a933bd9c887c77f4868071043349f94f1b896c1bcfb8"},
+ {file = "sentry-sdk-0.16.5.tar.gz", hash = "sha256:e12eb1c2c01cd9e9cfe70608dbda4ef451f37ef0b7cbb92e5d43f87c341d6334"},
+ {file = "sentry_sdk-0.16.5-py2.py3-none-any.whl", hash = "sha256:d359609e23ec9360b61e5ffdfa417e2f6bca281bfb869608c98c169c7e64acd5"},
]
shortuuid = [
{file = "shortuuid-1.0.1-py3-none-any.whl", hash = "sha256:492c7402ff91beb1342a5898bd61ea953985bf24a41cd9f247409aa2e03c8f77"},
@@ -2716,8 +2806,8 @@ snowballstemmer = [
{file = "snowballstemmer-2.0.0.tar.gz", hash = "sha256:df3bac3df4c2c01363f3dd2cfa78cce2840a79b9f1c2d2de9ce8d31683992f52"},
]
sphinx = [
- {file = "Sphinx-3.1.2-py3-none-any.whl", hash = "sha256:97dbf2e31fc5684bb805104b8ad34434ed70e6c588f6896991b2fdfd2bef8c00"},
- {file = "Sphinx-3.1.2.tar.gz", hash = "sha256:b9daeb9b39aa1ffefc2809b43604109825300300b987a24f45976c001ba1a8fd"},
+ {file = "Sphinx-3.2.1-py3-none-any.whl", hash = "sha256:ce6fd7ff5b215af39e2fcd44d4a321f6694b4530b6f2b2109b64d120773faea0"},
+ {file = "Sphinx-3.2.1.tar.gz", hash = "sha256:321d6d9b16fa381a5306e5a0b76cd48ffbc588e6340059a729c6fdd66087e0e8"},
]
sphinx-autodoc-typehints = [
{file = "sphinx-autodoc-typehints-1.11.0.tar.gz", hash = "sha256:bbf0b203f1019b0f9843ee8eef0cff856dc04b341f6dbe1113e37f2ebf243e11"},
@@ -2772,28 +2862,24 @@ toml = [
{file = "toml-0.10.1.tar.gz", hash = "sha256:926b612be1e5ce0634a2ca03470f95169cf16f939018233a670519cb4ac58b0f"},
]
torch = [
- {file = "torch-1.5.1-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:b84fd18fd8216b74a19828433c3beeb1f0d1d29f45dead3be9ed784ae6855966"},
- {file = "torch-1.5.1-cp35-none-macosx_10_6_x86_64.whl", hash = "sha256:5d909a55cd979fec2c9a7aa35012024b9cc106acbc496faf5de798b148406450"},
- {file = "torch-1.5.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:a358cee1d35b86757bf915e320ba776d39c20e60db50779060842efc86f02edd"},
- {file = "torch-1.5.1-cp36-none-macosx_10_9_x86_64.whl", hash = "sha256:0a83f41140222c7cc947aa29ed253f3e6fa490606d3d4acd02bfd9f338e3c707"},
- {file = "torch-1.5.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:70046cf66eb40ead89df25b8dcc571c3007fc9849d4e1d254cc09b4b355374d4"},
- {file = "torch-1.5.1-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:bb2a3e6c9c9dbfda856bd1b1a55d88789a9488b569ffba9cd6d9aa536ef866ba"},
- {file = "torch-1.5.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:c42658f2982591dc4d0459645c9ab26e0ce18aa7ab0993c27c8bcb1c98931d11"},
- {file = "torch-1.5.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:ff1dbeaa017bae66036e8e7a698a5475ac5a0d7b0a690f0a04ac3b1133b1feb3"},
+ {file = "torch-1.6.0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:7669f4d923b5758e28b521ea749c795ed67ff24b45ba20296bc8cff706d08df8"},
+ {file = "torch-1.6.0-cp36-none-macosx_10_9_x86_64.whl", hash = "sha256:728facb972a5952323c6d790c2c5922b2b35c44b0bc7bdfa02f8639727671a0c"},
+ {file = "torch-1.6.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:87d65c01d1b70bb46070824f28bfd93c86d3c5c56b90cbbe836a3f2491d91c76"},
+ {file = "torch-1.6.0-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:3838bd01af7dfb1f78573973f6842ce75b17e8e4f22be99c891dcb7c94bc13f5"},
+ {file = "torch-1.6.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:5357873e243bcfa804c32dc341f564e9a4c12addfc9baae4ee857fcc09a0a216"},
+ {file = "torch-1.6.0-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:4f9a4ad7947cef566afb0a323d99009fe8524f0b0f2ca1fb7ad5de0400381a5b"},
]
torchsummary = [
{file = "torchsummary-1.5.1-py3-none-any.whl", hash = "sha256:10f41d1743fb918f83293f13183f532ab1bb8f6639a1b89e5f8592ec1919a976"},
{file = "torchsummary-1.5.1.tar.gz", hash = "sha256:981bf689e22e0cf7f95c746002f20a24ad26aa6b9d861134a14bc6ce92230590"},
]
torchvision = [
- {file = "torchvision-0.6.1-cp35-cp35m-macosx_10_6_x86_64.whl", hash = "sha256:eb6d7ef73ab8ed756d18c1c11ead3ba1be7b3e2fe5bf475e16a4426d7f3d6eec"},
- {file = "torchvision-0.6.1-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:c122386a4604a6b611626d817981777bb06e747161a4954e4037c2e297fe1d86"},
- {file = "torchvision-0.6.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:9a544f241e1df7485c1fda64f06af4962ad4af1da6caeda352757ea800dd794d"},
- {file = "torchvision-0.6.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:c44819c3c903f2d6d269713a9aa81a9dcb0ab7716b4fc4dbdccf596e6fe894c9"},
- {file = "torchvision-0.6.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b07e7c5a7274082a8bbe88975c817da10bffaf4020017347d185b5d4eacb891d"},
- {file = "torchvision-0.6.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:52be3710bf495fee9eb8f1cb4d310edd2da94b127b2ca0aa77075c1001bcbb91"},
- {file = "torchvision-0.6.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:80a8beaa3416035c9640d153b1f0a7c9b09aac8037a032bbc33fd4bc8657b091"},
- {file = "torchvision-0.6.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:3af59c61d20bb0bb63dbd50f324fcd975fc5b89fb78e6b5714ae613cd7b734cc"},
+ {file = "torchvision-0.7.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:a70d80bb8749c1e4a46fa56dc2fc857e98d14600841e02cc2fed766daf96c245"},
+ {file = "torchvision-0.7.0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:14c0bf60fa26aabaea64ef30b8e5d441ee78d1a5eed568c30806af19bbe6b638"},
+ {file = "torchvision-0.7.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8c8df7e1d1f3d4e088256be1e8c8d3eb90b016302baa4649742d47ae1531da37"},
+ {file = "torchvision-0.7.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:0d1a5adfef4387659c7a0af3b72e16caa0c67224a422050ab65184d13ac9fb13"},
+ {file = "torchvision-0.7.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f5686e0a0dd511ac33eb9d6279bd34edd9f282dcb7c8ad21e290882c6206504f"},
+ {file = "torchvision-0.7.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:cfa2b367bc9acf20f18b151d0525970279719e81969c17214effe77245875354"},
]
tornado = [
{file = "tornado-6.0.4-cp35-cp35m-win32.whl", hash = "sha256:5217e601700f24e966ddab689f90b7ea4bd91ff3357c3600fa1045e26d68e55d"},
@@ -2807,8 +2893,8 @@ tornado = [
{file = "tornado-6.0.4.tar.gz", hash = "sha256:0fe2d45ba43b00a41cd73f8be321a44936dc1aba233dee979f17a042b83eb6dc"},
]
tqdm = [
- {file = "tqdm-4.48.0-py2.py3-none-any.whl", hash = "sha256:fcb7cb5b729b60a27f300b15c1ffd4744f080fb483b88f31dc8654b082cc8ea5"},
- {file = "tqdm-4.48.0.tar.gz", hash = "sha256:6baa75a88582b1db6d34ce4690da5501d2a1cb65c34664840a456b2c9f794d29"},
+ {file = "tqdm-4.48.2-py2.py3-none-any.whl", hash = "sha256:1a336d2b829be50e46b84668691e0a2719f26c97c62846298dd5ae2937e4d5cf"},
+ {file = "tqdm-4.48.2.tar.gz", hash = "sha256:564d632ea2b9cb52979f7956e093e831c28d441c11751682f84c86fc46e4fd21"},
]
traitlets = [
{file = "traitlets-4.3.3-py2.py3-none-any.whl", hash = "sha256:70b4c6a1d9019d7b4f6846832288f86998aa3b9207c6821f3578a6a6a467fe44"},
@@ -2856,8 +2942,8 @@ urllib3 = [
{file = "urllib3-1.25.10.tar.gz", hash = "sha256:91056c15fa70756691db97756772bb1eb9678fa585d9184f24534b100dc60f4a"},
]
wandb = [
- {file = "wandb-0.9.4-py2.py3-none-any.whl", hash = "sha256:6c2b2fff356803f38a2620737f7ba280c5bc770146c711c52a2f4d233b1a0579"},
- {file = "wandb-0.9.4.tar.gz", hash = "sha256:b3afba328455e885fd0af74fb8e40eb9fe1290d18d1e24181ab6f551f888532d"},
+ {file = "wandb-0.9.5-py2.py3-none-any.whl", hash = "sha256:f8aa54c75736aafdea600b2d3b803d4d3f999d1ac2568c253552dd557b767d9a"},
+ {file = "wandb-0.9.5.tar.gz", hash = "sha256:251ac7bcf9acc1b4ed08b1f3ea9172ed02d1e55a1a9033785e544f009208958e"},
]
watchdog = [
{file = "watchdog-0.10.3.tar.gz", hash = "sha256:4214e1379d128b0588021880ccaf40317ee156d4603ac388b9adcf29165e0c04"},
diff --git a/pyproject.toml b/pyproject.toml
index 15d1f57..49bb049 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,8 +22,8 @@ sphinx_rtd_theme = "^0.4.3"
boltons = "^20.1.0"
h5py = "^2.10.0"
toml = "^0.10.1"
-torch = "^1.5.0"
-torchvision = "^0.6.0"
+torch = "^1.6.0"
+torchvision = "^0.7.0"
torchsummary = "^1.5.1"
loguru = "^0.5.0"
matplotlib = "^3.2.1"
@@ -32,6 +32,7 @@ pytest = "^5.4.3"
opencv-python = "^4.3.0"
nltk = "^3.5"
einops = "^0.2.0"
+wandb = "^0.9.5"
[tool.poetry.dev-dependencies]
pytest = "^5.4.2"
diff --git a/src/notebooks/00-testing-stuff-out.ipynb b/src/notebooks/00-testing-stuff-out.ipynb
index 49ca4c4..3f008c3 100644
--- a/src/notebooks/00-testing-stuff-out.ipynb
+++ b/src/notebooks/00-testing-stuff-out.ipynb
@@ -2,11 +2,120 @@
"cells": [
{
"cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "\n",
+ "%matplotlib inline\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "from PIL import Image\n",
+ "import torch\n",
+ "from importlib.util import find_spec\n",
+ "if find_spec(\"text_recognizer\") is None:\n",
+ " import sys\n",
+ " sys.path.append('..')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "True"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "torch.cuda.is_available()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.nn.modules.activation.SELU"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "torch.nn.SELU"
+ ]
+ },
+ {
+ "cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
- "import torch"
+ "a = \"nNone\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "b = a or \"relu\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'nnone'"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "b.lower()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'nNone'"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "b"
]
},
{
@@ -986,28 +1095,16 @@
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": 51,
"metadata": {},
- "outputs": [
- {
- "ename": "ModuleNotFoundError",
- "evalue": "No module named 'tqdm.auto.tqdm'; 'tqdm.auto' is not a package",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m<ipython-input-20-68e3c8bf3e1f>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtqdm\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mtqdm_auto\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
- "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'tqdm.auto.tqdm'; 'tqdm.auto' is not a package"
- ]
- }
- ],
+ "outputs": [],
"source": [
- "import tqdm.auto.tqdm as tqdm_auto"
+ "import tqdm"
]
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 52,
"metadata": {},
"outputs": [
{
@@ -1016,7 +1113,7 @@
"tqdm.notebook.tqdm_notebook"
]
},
- "execution_count": 19,
+ "execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
@@ -1027,25 +1124,50 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tqdm.auto.tqdm"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
"def test():\n",
- " for i in range(9):\n",
+ " for i in tqdm.auto.tqdm(range(9)):\n",
" pass\n",
- " print(i)"
+ " print(i)\n",
+ " "
]
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 55,
"metadata": {},
"outputs": [
{
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "e1d3b25d4ee141e882e316ec54e79d60",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
"name": "stdout",
"output_type": "stream",
"text": [
+ "\n",
"8\n"
]
}
@@ -1056,6 +1178,479 @@
},
{
"cell_type": "code",
+ "execution_count": 58,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from time import sleep"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 71,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "41b743273ce14236bcb65782dbcd2e75",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "pbar = tqdm.auto.tqdm([\"a\", \"b\", \"c\", \"d\"], leave=True)\n",
+ "for char in pbar:\n",
+ " pbar.set_description(\"Processing %s\" % char)\n",
+ "# pbar.set_prefix()\n",
+ " sleep(0.25)\n",
+ "pbar.set_postfix({\"hej\": 0.32})"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 83,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pbar.close()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 96,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "cb5ad8d6109f4b1495b8fc7422bafd01",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "with tqdm.auto.tqdm(total=10, bar_format=\"{postfix[0]} {postfix[1][value]:>8.2g}\",\n",
+ " postfix=[\"Batch\", dict(value=0)]) as t:\n",
+ " for i in range(10):\n",
+ " sleep(0.1)\n",
+ "# t.postfix[2][\"value\"] = 3 \n",
+ " t.postfix[1][\"value\"] = i / 2\n",
+ " t.update()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 99,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "0b341d49ad074823881e84a538bcad0c",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(FloatProgress(value=0.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "with tqdm.auto.tqdm(total=100, leave=True) as pbar:\n",
+ " for i in range(2):\n",
+ " for i in range(10):\n",
+ " sleep(0.1)\n",
+ " pbar.update(10)\n",
+ " pbar.set_postfix({\"adaf\": 23})\n",
+ " pbar.set_postfix({\"hej\": 0.32})\n",
+ " pbar.reset()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.networks.residual_network import IdentityBlock, ResidualBlock, BasicBlock, BottleNeckBlock, ResidualLayer, Encoder, ResidualNetwork"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "IdentityBlock(\n",
+ " (blocks): Identity()\n",
+ " (activation_fn): ReLU(inplace=True)\n",
+ " (shortcut): Identity()\n",
+ ")"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "IdentityBlock(32, 64)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "ResidualBlock(\n",
+ " (blocks): Identity()\n",
+ " (activation_fn): ReLU(inplace=True)\n",
+ " (shortcut): Sequential(\n",
+ " (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "ResidualBlock(32, 64)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "BasicBlock(\n",
+ " (blocks): Sequential(\n",
+ " (0): Sequential(\n",
+ " (0): Conv2dAuto(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Sequential(\n",
+ " (0): Conv2dAuto(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (activation_fn): ReLU(inplace=True)\n",
+ " (shortcut): Sequential(\n",
+ " (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ ")\n"
+ ]
+ }
+ ],
+ "source": [
+ "dummy = torch.ones((1, 32, 224, 224))\n",
+ "\n",
+ "block = BasicBlock(32, 64)\n",
+ "block(dummy).shape\n",
+ "print(block)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "BottleNeckBlock(\n",
+ " (blocks): Sequential(\n",
+ " (0): Sequential(\n",
+ " (0): Conv2dAuto(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Sequential(\n",
+ " (0): Conv2dAuto(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " (3): ReLU(inplace=True)\n",
+ " (4): Sequential(\n",
+ " (0): Conv2dAuto(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (activation_fn): ReLU(inplace=True)\n",
+ " (shortcut): Sequential(\n",
+ " (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ ")\n"
+ ]
+ }
+ ],
+ "source": [
+ "dummy = torch.ones((1, 32, 10, 10))\n",
+ "\n",
+ "block = BottleNeckBlock(32, 64)\n",
+ "block(dummy).shape\n",
+ "print(block)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 128, 24, 24])"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dummy = torch.ones((1, 64, 48, 48))\n",
+ "\n",
+ "layer = ResidualLayer(64, 128, block=BasicBlock, num_blocks=3)\n",
+ "layer(dummy).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[(64, 128), (128, 256), (256, 512)]"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "blocks_sizes=[64, 128, 256, 512]\n",
+ "list(zip(blocks_sizes, blocks_sizes[1:]))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "e = Encoder(depths=[1, 1])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from torchsummary import summary"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "----------------------------------------------------------------\n",
+ " Layer (type) Output Shape Param #\n",
+ "================================================================\n",
+ " Conv2d-1 [-1, 32, 15, 15] 800\n",
+ " BatchNorm2d-2 [-1, 32, 15, 15] 64\n",
+ " ReLU-3 [-1, 32, 15, 15] 0\n",
+ " MaxPool2d-4 [-1, 32, 8, 8] 0\n",
+ " Conv2dAuto-5 [-1, 32, 8, 8] 9,216\n",
+ " BatchNorm2d-6 [-1, 32, 8, 8] 64\n",
+ " ReLU-7 [-1, 32, 8, 8] 0\n",
+ " ReLU-8 [-1, 32, 8, 8] 0\n",
+ " Conv2dAuto-9 [-1, 32, 8, 8] 9,216\n",
+ " BatchNorm2d-10 [-1, 32, 8, 8] 64\n",
+ " ReLU-11 [-1, 32, 8, 8] 0\n",
+ " ReLU-12 [-1, 32, 8, 8] 0\n",
+ " BasicBlock-13 [-1, 32, 8, 8] 0\n",
+ " ResidualLayer-14 [-1, 32, 8, 8] 0\n",
+ " Conv2d-15 [-1, 64, 4, 4] 2,048\n",
+ " BatchNorm2d-16 [-1, 64, 4, 4] 128\n",
+ " Conv2dAuto-17 [-1, 64, 4, 4] 18,432\n",
+ " BatchNorm2d-18 [-1, 64, 4, 4] 128\n",
+ " ReLU-19 [-1, 64, 4, 4] 0\n",
+ " ReLU-20 [-1, 64, 4, 4] 0\n",
+ " Conv2dAuto-21 [-1, 64, 4, 4] 36,864\n",
+ " BatchNorm2d-22 [-1, 64, 4, 4] 128\n",
+ " ReLU-23 [-1, 64, 4, 4] 0\n",
+ " ReLU-24 [-1, 64, 4, 4] 0\n",
+ " BasicBlock-25 [-1, 64, 4, 4] 0\n",
+ " ResidualLayer-26 [-1, 64, 4, 4] 0\n",
+ "================================================================\n",
+ "Total params: 77,152\n",
+ "Trainable params: 77,152\n",
+ "Non-trainable params: 0\n",
+ "----------------------------------------------------------------\n",
+ "Input size (MB): 0.00\n",
+ "Forward/backward pass size (MB): 0.43\n",
+ "Params size (MB): 0.29\n",
+ "Estimated Total Size (MB): 0.73\n",
+ "----------------------------------------------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "summary(e, (1, 28, 28), device=\"cpu\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "resnet = ResidualNetwork(1, 80, activation=\"selu\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "----------------------------------------------------------------\n",
+ " Layer (type) Output Shape Param #\n",
+ "================================================================\n",
+ " Conv2d-1 [-1, 32, 15, 15] 800\n",
+ " BatchNorm2d-2 [-1, 32, 15, 15] 64\n",
+ " SELU-3 [-1, 32, 15, 15] 0\n",
+ " MaxPool2d-4 [-1, 32, 8, 8] 0\n",
+ " Conv2dAuto-5 [-1, 32, 8, 8] 9,216\n",
+ " BatchNorm2d-6 [-1, 32, 8, 8] 64\n",
+ " SELU-7 [-1, 32, 8, 8] 0\n",
+ " SELU-8 [-1, 32, 8, 8] 0\n",
+ " Conv2dAuto-9 [-1, 32, 8, 8] 9,216\n",
+ " BatchNorm2d-10 [-1, 32, 8, 8] 64\n",
+ " SELU-11 [-1, 32, 8, 8] 0\n",
+ " SELU-12 [-1, 32, 8, 8] 0\n",
+ " BasicBlock-13 [-1, 32, 8, 8] 0\n",
+ " Conv2dAuto-14 [-1, 32, 8, 8] 9,216\n",
+ " BatchNorm2d-15 [-1, 32, 8, 8] 64\n",
+ " SELU-16 [-1, 32, 8, 8] 0\n",
+ " SELU-17 [-1, 32, 8, 8] 0\n",
+ " Conv2dAuto-18 [-1, 32, 8, 8] 9,216\n",
+ " BatchNorm2d-19 [-1, 32, 8, 8] 64\n",
+ " SELU-20 [-1, 32, 8, 8] 0\n",
+ " SELU-21 [-1, 32, 8, 8] 0\n",
+ " BasicBlock-22 [-1, 32, 8, 8] 0\n",
+ " ResidualLayer-23 [-1, 32, 8, 8] 0\n",
+ " Conv2d-24 [-1, 64, 4, 4] 2,048\n",
+ " BatchNorm2d-25 [-1, 64, 4, 4] 128\n",
+ " Conv2dAuto-26 [-1, 64, 4, 4] 18,432\n",
+ " BatchNorm2d-27 [-1, 64, 4, 4] 128\n",
+ " SELU-28 [-1, 64, 4, 4] 0\n",
+ " SELU-29 [-1, 64, 4, 4] 0\n",
+ " Conv2dAuto-30 [-1, 64, 4, 4] 36,864\n",
+ " BatchNorm2d-31 [-1, 64, 4, 4] 128\n",
+ " SELU-32 [-1, 64, 4, 4] 0\n",
+ " SELU-33 [-1, 64, 4, 4] 0\n",
+ " BasicBlock-34 [-1, 64, 4, 4] 0\n",
+ " Conv2dAuto-35 [-1, 64, 4, 4] 36,864\n",
+ " BatchNorm2d-36 [-1, 64, 4, 4] 128\n",
+ " SELU-37 [-1, 64, 4, 4] 0\n",
+ " SELU-38 [-1, 64, 4, 4] 0\n",
+ " Conv2dAuto-39 [-1, 64, 4, 4] 36,864\n",
+ " BatchNorm2d-40 [-1, 64, 4, 4] 128\n",
+ " SELU-41 [-1, 64, 4, 4] 0\n",
+ " SELU-42 [-1, 64, 4, 4] 0\n",
+ " BasicBlock-43 [-1, 64, 4, 4] 0\n",
+ " ResidualLayer-44 [-1, 64, 4, 4] 0\n",
+ " Encoder-45 [-1, 64, 4, 4] 0\n",
+ " Reduce-46 [-1, 64] 0\n",
+ " Linear-47 [-1, 80] 5,200\n",
+ " Decoder-48 [-1, 80] 0\n",
+ "================================================================\n",
+ "Total params: 174,896\n",
+ "Trainable params: 174,896\n",
+ "Non-trainable params: 0\n",
+ "----------------------------------------------------------------\n",
+ "Input size (MB): 0.00\n",
+ "Forward/backward pass size (MB): 0.65\n",
+ "Params size (MB): 0.67\n",
+ "Estimated Total Size (MB): 1.32\n",
+ "----------------------------------------------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "summary(resnet, (1, 28, 28), device=\"cpu\")"
+ ]
+ },
+ {
+ "cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
diff --git a/src/notebooks/01-look-at-emnist.ipynb b/src/notebooks/01-look-at-emnist.ipynb
index a68b418..8648afb 100644
--- a/src/notebooks/01-look-at-emnist.ipynb
+++ b/src/notebooks/01-look-at-emnist.ipynb
@@ -31,12 +31,31 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
- "dataset = EmnistDataset()\n",
- "dataset.load_emnist_dataset()"
+ "dataset = EmnistDataset()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Tensor"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "type(dataset.data)"
]
},
{
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index 96f84e5..49ebad3 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -8,6 +8,7 @@ from loguru import logger
import numpy as np
from PIL import Image
import torch
+from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import EMNIST
from torchvision.transforms import Compose, Normalize, ToTensor
@@ -183,12 +184,8 @@ class EmnistDataset(Dataset):
self.input_shape = self._mapper.input_shape
self.num_classes = self._mapper.num_classes
- # Placeholders
- self.data = None
- self.targets = None
-
# Load dataset.
- self.load_emnist_dataset()
+ self.data, self.targets = self.load_emnist_dataset()
@property
def mapper(self) -> EmnistMapper:
@@ -199,9 +196,7 @@ class EmnistDataset(Dataset):
"""Returns the length of the dataset."""
return len(self.data)
- def __getitem__(
- self, index: Union[int, torch.Tensor]
- ) -> Tuple[torch.Tensor, torch.Tensor]:
+ def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
"""Fetches samples from the dataset.
Args:
@@ -239,11 +234,13 @@ class EmnistDataset(Dataset):
f"Mapping: {self.mapper.mapping}\n"
)
- def _sample_to_balance(self) -> None:
+ def _sample_to_balance(
+ self, data: Tensor, targets: Tensor
+ ) -> Tuple[np.ndarray, np.ndarray]:
"""Because the dataset is not balanced, we take at most the mean number of instances per class."""
np.random.seed(self.seed)
- x = self.data
- y = self.targets
+ x = data
+ y = targets
num_to_sample = int(np.bincount(y.flatten()).mean())
all_sampled_indices = []
for label in np.unique(y.flatten()):
@@ -253,20 +250,22 @@ class EmnistDataset(Dataset):
indices = np.concatenate(all_sampled_indices)
x_sampled = x[indices]
y_sampled = y[indices]
- self.data = x_sampled
- self.targets = y_sampled
+ data = x_sampled
+ targets = y_sampled
+ return data, targets
- def _subsample(self) -> None:
+ def _subsample(self, data: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
"""Subsamples the dataset to the specified fraction."""
- x = self.data
- y = self.targets
+ x = data
+ y = targets
num_samples = int(x.shape[0] * self.subsample_fraction)
x_sampled = x[:num_samples]
y_sampled = y[:num_samples]
self.data = x_sampled
self.targets = y_sampled
+ return data, targets
- def load_emnist_dataset(self) -> None:
+ def load_emnist_dataset(self) -> Tuple[Tensor, Tensor]:
"""Fetch the EMNIST dataset."""
dataset = EMNIST(
root=DATA_DIRNAME,
@@ -277,11 +276,13 @@ class EmnistDataset(Dataset):
target_transform=None,
)
- self.data = dataset.data
- self.targets = dataset.targets
+ data = dataset.data
+ targets = dataset.targets
if self.sample_to_balance:
- self._sample_to_balance()
+ data, targets = self._sample_to_balance(data, targets)
if self.subsample_fraction is not None:
- self._subsample()
+ data, targets = self._subsample(data, targets)
+
+ return data, targets
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index d64a991..b0617f5 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -8,6 +8,7 @@ import h5py
from loguru import logger
import numpy as np
import torch
+from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, Normalize, ToTensor
@@ -87,16 +88,14 @@ class EmnistLinesDataset(Dataset):
"""Returns the length of the dataset."""
return len(self.data)
- def __getitem__(
- self, index: Union[int, torch.Tensor]
- ) -> Tuple[torch.Tensor, torch.Tensor]:
+ def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
"""Fetches data, target pair of the dataset for a given and index or indices.
Args:
- index (Union[int, torch.Tensor]): Either a list or int of indices/index.
+ index (Union[int, Tensor]): Either a list or int of indices/index.
Returns:
- Tuple[torch.Tensor, torch.Tensor]: Data target pair.
+ Tuple[Tensor, Tensor]: Data target pair.
"""
if torch.is_tensor(index):
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index 6d40b49..74fd223 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -53,8 +53,8 @@ class Model(ABC):
"""
- # Fetch data loaders and dataset info.
- dataset_name, self._data_loaders, self._mapper = self._load_data_loader(
+ # Configure data loaders and dataset info.
+ dataset_name, self._data_loaders, self._mapper = self._configure_data_loader(
data_loader_args
)
self._input_shape = self._mapper.input_shape
@@ -70,16 +70,19 @@ class Model(ABC):
else:
self._device = device
- # Load network.
- self._network, self._network_args = self._load_network(network_fn, network_args)
+ # Configure network.
+ self._network, self._network_args = self._configure_network(
+ network_fn, network_args
+ )
# To device.
self._network.to(self._device)
- # Set training objects.
- self._criterion = self._load_criterion(criterion, criterion_args)
- self._optimizer = self._load_optimizer(optimizer, optimizer_args)
- self._lr_scheduler = self._load_lr_scheduler(lr_scheduler, lr_scheduler_args)
+ # Configure training objects.
+ self._criterion = self._configure_criterion(criterion, criterion_args)
+ self._optimizer, self._lr_scheduler = self._configure_optimizers(
+ optimizer, optimizer_args, lr_scheduler, lr_scheduler_args
+ )
# Experiment directory.
self.model_dir = None
@@ -87,7 +90,7 @@ class Model(ABC):
# Flag for stopping training.
self.stop_training = False
- def _load_data_loader(
+ def _configure_data_loader(
self, data_loader_args: Optional[Dict]
) -> Tuple[str, Dict, EmnistMapper]:
"""Loads data loader, dataset name, and dataset mapper."""
@@ -102,7 +105,7 @@ class Model(ABC):
data_loaders = None
return dataset_name, data_loaders, mapper
- def _load_network(
+ def _configure_network(
self, network_fn: Type[nn.Module], network_args: Optional[Dict]
) -> Tuple[Type[nn.Module], Dict]:
"""Loads the network."""
@@ -113,7 +116,7 @@ class Model(ABC):
network = network_fn(**network_args)
return network, network_args
- def _load_criterion(
+ def _configure_criterion(
self, criterion: Optional[Callable], criterion_args: Optional[Dict]
) -> Optional[Callable]:
"""Loads the criterion."""
@@ -123,27 +126,27 @@ class Model(ABC):
_criterion = None
return _criterion
- def _load_optimizer(
- self, optimizer: Optional[Callable], optimizer_args: Optional[Dict]
- ) -> Optional[Callable]:
- """Loads the optimizer."""
+ def _configure_optimizers(
+ self,
+ optimizer: Optional[Callable],
+ optimizer_args: Optional[Dict],
+ lr_scheduler: Optional[Callable],
+ lr_scheduler_args: Optional[Dict],
+ ) -> Tuple[Optional[Callable], Optional[Callable]]:
+ """Loads the optimizers."""
if optimizer is not None:
_optimizer = optimizer(self._network.parameters(), **optimizer_args)
else:
_optimizer = None
- return _optimizer
- def _load_lr_scheduler(
- self, lr_scheduler: Optional[Callable], lr_scheduler_args: Optional[Dict]
- ) -> Optional[Callable]:
- """Loads learning rate scheduler."""
if self._optimizer and lr_scheduler is not None:
if "OneCycleLR" in str(lr_scheduler):
lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders["train"])
_lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args)
else:
_lr_scheduler = None
- return _lr_scheduler
+
+ return _optimizer, _lr_scheduler
@property
def __name__(self) -> str:
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
index 0a0ab2d..0fd7afd 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/src/text_recognizer/models/character_model.py
@@ -44,6 +44,7 @@ class CharacterModel(Model):
self.tensor_transform = ToTensor()
self.softmax = nn.Softmax(dim=0)
+ @torch.no_grad()
def predict_on_image(
self, image: Union[np.ndarray, torch.Tensor]
) -> Tuple[str, float]:
@@ -64,10 +65,9 @@ class CharacterModel(Model):
# If the image is an unscaled tensor.
image = image.type("torch.FloatTensor") / 255
- with torch.no_grad():
- # Put the image tensor on the device the model weights are on.
- image = image.to(self.device)
- logits = self.network(image)
+ # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device)
+ logits = self.network(image)
prediction = self.softmax(logits.data.squeeze())
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
index e6b6946..a83ca35 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/src/text_recognizer/networks/__init__.py
@@ -1,5 +1,6 @@
"""Network modules."""
from .lenet import LeNet
from .mlp import MLP
+from .residual_network import ResidualNetwork
-__all__ = ["MLP", "LeNet"]
+__all__ = ["MLP", "LeNet", "ResidualNetwork"]
diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py
index cbc58fc..91d3f2c 100644
--- a/src/text_recognizer/networks/lenet.py
+++ b/src/text_recognizer/networks/lenet.py
@@ -5,6 +5,8 @@ from einops.layers.torch import Rearrange
import torch
from torch import nn
+from text_recognizer.networks.misc import activation_function
+
class LeNet(nn.Module):
"""LeNet network."""
@@ -16,8 +18,7 @@ class LeNet(nn.Module):
hidden_size: Tuple[int, ...] = (9216, 128),
dropout_rate: float = 0.2,
output_size: int = 10,
- activation_fn: Optional[Callable] = None,
- activation_fn_args: Optional[Dict] = None,
+ activation_fn: Optional[str] = "relu",
) -> None:
"""The LeNet network.
@@ -28,18 +29,12 @@ class LeNet(nn.Module):
Defaults to (9216, 128).
dropout_rate (float): The dropout rate. Defaults to 0.2.
output_size (int): Number of classes. Defaults to 10.
- activation_fn (Optional[Callable]): The non-linear activation function. Defaults to
- nn.ReLU(inplace).
- activation_fn_args (Optional[Dict]): The arguments for the activation function. Defaults to None.
+ activation_fn (Optional[str]): The name of non-linear activation function. Defaults to relu.
"""
super().__init__()
- if activation_fn is not None:
- activation_fn_args = activation_fn_args or {}
- activation_fn = getattr(nn, activation_fn)(**activation_fn_args)
- else:
- activation_fn = nn.ReLU(inplace=True)
+ activation_fn = activation_function(activation_fn)
self.layers = [
nn.Conv2d(
@@ -66,7 +61,7 @@ class LeNet(nn.Module):
self.layers = nn.Sequential(*self.layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
- """The feedforward."""
+ """The feedforward pass."""
# If batch dimenstion is missing, it needs to be added.
if len(x.shape) == 3:
x = x.unsqueeze(0)
diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/misc.py
index 2fbab8f..6f61b5d 100644
--- a/src/text_recognizer/networks/misc.py
+++ b/src/text_recognizer/networks/misc.py
@@ -1,9 +1,9 @@
"""Miscellaneous neural network functionality."""
-from typing import Tuple
+from typing import Tuple, Type
from einops import rearrange
import torch
-from torch.nn import Unfold
+from torch import nn
def sliding_window(
@@ -20,10 +20,24 @@ def sliding_window(
torch.Tensor: A tensor with the shape (batch, patches, height, width).
"""
- unfold = Unfold(kernel_size=patch_size, stride=stride)
+ unfold = nn.Unfold(kernel_size=patch_size, stride=stride)
# Preform the slidning window, unsqueeze as the channel dimesion is lost.
patches = unfold(images).unsqueeze(1)
patches = rearrange(
patches, "b c (h w) t -> b t c h w", h=patch_size[0], w=patch_size[1]
)
return patches
+
+
+def activation_function(activation: str) -> Type[nn.Module]:
+ """Returns the callable activation function."""
+ activation_fns = nn.ModuleDict(
+ [
+ ["gelu", nn.GELU()],
+ ["leaky_relu", nn.LeakyReLU(negative_slope=1.0e-2, inplace=True)],
+ ["none", nn.Identity()],
+ ["relu", nn.ReLU(inplace=True)],
+ ["selu", nn.SELU(inplace=True)],
+ ]
+ )
+ return activation_fns[activation.lower()]
diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py
index ac2c825..acebdaa 100644
--- a/src/text_recognizer/networks/mlp.py
+++ b/src/text_recognizer/networks/mlp.py
@@ -5,6 +5,8 @@ from einops.layers.torch import Rearrange
import torch
from torch import nn
+from text_recognizer.networks.misc import activation_function
+
class MLP(nn.Module):
"""Multi layered perceptron network."""
@@ -16,8 +18,7 @@ class MLP(nn.Module):
hidden_size: Union[int, List] = 128,
num_layers: int = 3,
dropout_rate: float = 0.2,
- activation_fn: Optional[Callable] = None,
- activation_fn_args: Optional[Dict] = None,
+ activation_fn: str = "relu",
) -> None:
"""Initialization of the MLP network.
@@ -27,18 +28,13 @@ class MLP(nn.Module):
hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128.
num_layers (int): The number of hidden layers. Defaults to 3.
dropout_rate (float): The dropout rate at each layer. Defaults to 0.2.
- activation_fn (Optional[Callable]): The activation function in the hidden layers. Defaults to
- None.
- activation_fn_args (Optional[Dict]): The arguments for the activation function. Defaults to None.
+ activation_fn (str): Name of the activation function in the hidden layers. Defaults to
+ relu.
"""
super().__init__()
- if activation_fn is not None:
- activation_fn_args = activation_fn_args or {}
- activation_fn = getattr(nn, activation_fn)(**activation_fn_args)
- else:
- activation_fn = nn.ReLU(inplace=True)
+ activation_fn = activation_function(activation_fn)
if isinstance(hidden_size, int):
hidden_size = [hidden_size] * num_layers
@@ -65,7 +61,7 @@ class MLP(nn.Module):
self.layers = nn.Sequential(*self.layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
- """The feedforward."""
+ """The feedforward pass."""
# If batch dimenstion is missing, it needs to be added.
if len(x.shape) == 3:
x = x.unsqueeze(0)
diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py
index 23394b0..47e351a 100644
--- a/src/text_recognizer/networks/residual_network.py
+++ b/src/text_recognizer/networks/residual_network.py
@@ -1 +1,315 @@
"""Residual CNN."""
+from functools import partial
+from typing import Callable, Dict, List, Optional, Type, Union
+
+from einops.layers.torch import Rearrange, Reduce
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.misc import activation_function
+
+
+class Conv2dAuto(nn.Conv2d):
+ """Convolution with auto padding based on kernel size."""
+
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ self.padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2)
+
+
+def conv_bn(in_channels: int, out_channels: int, *args, **kwargs) -> nn.Sequential:
+ """3x3 convolution with batch norm."""
+ conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,)
+ return nn.Sequential(
+ conv3x3(in_channels, out_channels, *args, **kwargs),
+ nn.BatchNorm2d(out_channels),
+ )
+
+
+class IdentityBlock(nn.Module):
+ """Residual with identity block."""
+
+ def __init__(
+ self, in_channels: int, out_channels: int, activation: str = "relu"
+ ) -> None:
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.blocks = nn.Identity()
+ self.activation_fn = activation_function(activation)
+ self.shortcut = nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ residual = x
+ if self.apply_shortcut:
+ residual = self.shortcut(x)
+ x = self.blocks(x)
+ x += residual
+ x = self.activation_fn(x)
+ return x
+
+ @property
+ def apply_shortcut(self) -> bool:
+ """Check if shortcut should be applied."""
+ return self.in_channels != self.out_channels
+
+
+class ResidualBlock(IdentityBlock):
+ """Residual with nonlinear shortcut."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ expansion: int = 1,
+ downsampling: int = 1,
+ *args,
+ **kwargs
+ ) -> None:
+ """Short summary.
+
+ Args:
+ in_channels (int): Number of in channels.
+ out_channels (int): umber of out channels.
+ expansion (int): Expansion factor of the out channels. Defaults to 1.
+ downsampling (int): Downsampling factor used in stride. Defaults to 1.
+ *args (type): Extra arguments.
+ **kwargs (type): Extra key value arguments.
+
+ """
+ super().__init__(in_channels, out_channels, *args, **kwargs)
+ self.expansion = expansion
+ self.downsampling = downsampling
+
+ self.shortcut = (
+ nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.in_channels,
+ out_channels=self.expanded_channels,
+ kernel_size=1,
+ stride=self.downsampling,
+ bias=False,
+ ),
+ nn.BatchNorm2d(self.expanded_channels),
+ )
+ if self.apply_shortcut
+ else None
+ )
+
+ @property
+ def expanded_channels(self) -> int:
+ """Computes the expanded output channels."""
+ return self.out_channels * self.expansion
+
+ @property
+ def apply_shortcut(self) -> bool:
+ """Check if shortcut should be applied."""
+ return self.in_channels != self.expanded_channels
+
+
+class BasicBlock(ResidualBlock):
+ """Basic ResNet block."""
+
+ expansion = 1
+
+ def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None:
+ super().__init__(in_channels, out_channels, *args, **kwargs)
+ self.blocks = nn.Sequential(
+ conv_bn(
+ in_channels=self.in_channels,
+ out_channels=self.out_channels,
+ bias=False,
+ stride=self.downsampling,
+ ),
+ self.activation_fn,
+ conv_bn(
+ in_channels=self.out_channels,
+ out_channels=self.expanded_channels,
+ bias=False,
+ ),
+ )
+
+
+class BottleNeckBlock(ResidualBlock):
+ """Bottleneck block to increase depth while minimizing parameter size."""
+
+ expansion = 4
+
+ def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None:
+ super().__init__(in_channels, out_channels, *args, **kwargs)
+ self.blocks = nn.Sequential(
+ conv_bn(
+ in_channels=self.in_channels,
+ out_channels=self.out_channels,
+ kernel_size=1,
+ ),
+ self.activation_fn,
+ conv_bn(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels,
+ kernel_size=3,
+ stride=self.downsampling,
+ ),
+ self.activation_fn,
+ conv_bn(
+ in_channels=self.out_channels,
+ out_channels=self.expanded_channels,
+ kernel_size=1,
+ ),
+ )
+
+
+class ResidualLayer(nn.Module):
+ """ResNet layer."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ block: BasicBlock = BasicBlock,
+ num_blocks: int = 1,
+ *args,
+ **kwargs
+ ) -> None:
+ super().__init__()
+ downsampling = 2 if in_channels != out_channels else 1
+ self.blocks = nn.Sequential(
+ block(
+ in_channels, out_channels, *args, **kwargs, downsampling=downsampling
+ ),
+ *[
+ block(
+ out_channels * block.expansion,
+ out_channels,
+ downsampling=1,
+ *args,
+ **kwargs
+ )
+ for _ in range(num_blocks - 1)
+ ]
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ x = self.blocks(x)
+ return x
+
+
+class Encoder(nn.Module):
+ """Encoder network."""
+
+ def __init__(
+ self,
+ in_channels: int = 1,
+ block_sizes: List[int] = (32, 64),
+ depths: List[int] = (2, 2),
+ activation: str = "relu",
+ block: Type[nn.Module] = BasicBlock,
+ *args,
+ **kwargs
+ ) -> None:
+ super().__init__()
+
+ self.block_sizes = block_sizes
+ self.depths = depths
+ self.activation = activation
+
+ self.gate = nn.Sequential(
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=self.block_sizes[0],
+ kernel_size=3,
+ stride=2,
+ padding=3,
+ bias=False,
+ ),
+ nn.BatchNorm2d(self.block_sizes[0]),
+ activation_function(self.activation),
+ nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
+ )
+
+ self.blocks = self._configure_blocks(block)
+
+ def _configure_blocks(
+ self, block: Type[nn.Module], *args, **kwargs
+ ) -> nn.Sequential:
+ channels = [self.block_sizes[0]] + list(
+ zip(self.block_sizes, self.block_sizes[1:])
+ )
+ blocks = [
+ ResidualLayer(
+ in_channels=channels[0],
+ out_channels=channels[0],
+ num_blocks=self.depths[0],
+ block=block,
+ activation=self.activation,
+ *args,
+ **kwargs
+ )
+ ]
+ blocks += [
+ ResidualLayer(
+ in_channels=in_channels * block.expansion,
+ out_channels=out_channels,
+ num_blocks=num_blocks,
+ block=block,
+ activation=self.activation,
+ *args,
+ **kwargs
+ )
+ for (in_channels, out_channels), num_blocks in zip(
+ channels[1:], self.depths[1:]
+ )
+ ]
+
+ return nn.Sequential(*blocks)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ # If batch dimenstion is missing, it needs to be added.
+ if len(x.shape) == 3:
+ x = x.unsqueeze(0)
+ x = self.gate(x)
+ return self.blocks(x)
+
+
+class Decoder(nn.Module):
+ """Classification head."""
+
+ def __init__(self, in_features: int, num_classes: int = 80) -> None:
+ super().__init__()
+ self.decoder = nn.Sequential(
+ Reduce("b c h w -> b c", "mean"),
+ nn.Linear(in_features=in_features, out_features=num_classes),
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ return self.decoder(x)
+
+
+class ResidualNetwork(nn.Module):
+ """Full residual network."""
+
+ def __init__(self, in_channels: int, num_classes: int, *args, **kwargs) -> None:
+ super().__init__()
+ self.encoder = Encoder(in_channels, *args, **kwargs)
+ self.decoder = Decoder(
+ in_features=self.encoder.blocks[-1].blocks[-1].expanded_channels,
+ num_classes=num_classes,
+ )
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ x = self.encoder(x)
+ x = self.decoder(x)
+ return x
diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt
index 81ef9be..676eb44 100644
--- a/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt
+++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt
Binary files differ
diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt
index 49bd166..86cf103 100644
--- a/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt
+++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt
Binary files differ
diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt
new file mode 100644
index 0000000..008beb2
--- /dev/null
+++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt
Binary files differ
diff --git a/src/training/experiments/sample_experiment.yml b/src/training/experiments/sample_experiment.yml
index 355305c..bae02ac 100644
--- a/src/training/experiments/sample_experiment.yml
+++ b/src/training/experiments/sample_experiment.yml
@@ -9,25 +9,32 @@ experiments:
seed: 4711
data_loader_args:
splits: [train, val]
- batch_size: 256
shuffle: true
num_workers: 8
cuda: true
model: CharacterModel
metrics: [accuracy]
- network: MLP
+ # network: MLP
+ # network_args:
+ # input_size: 784
+ # hidden_size: 512
+ # output_size: 80
+ # num_layers: 3
+ # dropout_rate: 0
+ # activation_fn: SELU
+ network: ResidualNetwork
network_args:
- input_size: 784
- output_size: 62
- num_layers: 3
- activation_fn: GELU
+ in_channels: 1
+ num_classes: 80
+ depths: [1, 1]
+ block_sizes: [128, 256]
# network: LeNet
# network_args:
# output_size: 62
# activation_fn: GELU
train_args:
batch_size: 256
- epochs: 16
+ epochs: 32
criterion: CrossEntropyLoss
criterion_args:
weight: null
@@ -43,20 +50,24 @@ experiments:
# centered: false
optimizer: AdamW
optimizer_args:
- lr: 1.e-2
+ lr: 1.e-03
betas: [0.9, 0.999]
eps: 1.e-08
- weight_decay: 0
+ # weight_decay: 5.e-4
amsgrad: false
# lr_scheduler: null
lr_scheduler: OneCycleLR
lr_scheduler_args:
- max_lr: 1.e-3
- epochs: 16
- callbacks: [Checkpoint, EarlyStopping, WandbCallback, WandbImageLogger, OneCycleLR]
+ max_lr: 1.e-03
+ epochs: 32
+ anneal_strategy: linear
+ callbacks: [Checkpoint, ProgressBar, EarlyStopping, WandbCallback, WandbImageLogger, OneCycleLR]
callback_args:
Checkpoint:
monitor: val_accuracy
+ ProgressBar:
+ epochs: 32
+ log_batch_frequency: 100
EarlyStopping:
monitor: val_loss
min_delta: 0.0
@@ -68,5 +79,5 @@ experiments:
num_examples: 4
OneCycleLR:
null
- verbosity: 2 # 0, 1, 2
+ verbosity: 1 # 0, 1, 2
resume_experiment: null
diff --git a/src/training/prepare_experiments.py b/src/training/prepare_experiments.py
index 97c0304..4c3f9ba 100644
--- a/src/training/prepare_experiments.py
+++ b/src/training/prepare_experiments.py
@@ -7,11 +7,11 @@ from loguru import logger
import yaml
-# flake8: noqa: S404,S607,S603
def run_experiments(experiments_filename: str) -> None:
"""Run experiment from file."""
with open(experiments_filename) as f:
experiments_config = yaml.safe_load(f)
+
num_experiments = len(experiments_config["experiments"])
for index in range(num_experiments):
experiment_config = experiments_config["experiments"][index]
@@ -27,10 +27,10 @@ def run_experiments(experiments_filename: str) -> None:
type=str,
help="Filename of Yaml file of experiments to run.",
)
-def main(experiments_filename: str) -> None:
+def run_cli(experiments_filename: str) -> None:
"""Parse command-line arguments and run experiments from provided file."""
run_experiments(experiments_filename)
if __name__ == "__main__":
- main()
+ run_cli()
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py
index d278dc2..8c063ff 100644
--- a/src/training/run_experiment.py
+++ b/src/training/run_experiment.py
@@ -6,18 +6,20 @@ import json
import os
from pathlib import Path
import re
-from typing import Callable, Dict, Tuple
+from typing import Callable, Dict, Tuple, Type
import click
from loguru import logger
import torch
from tqdm import tqdm
-from training.callbacks import CallbackList
from training.gpu_manager import GPUManager
-from training.train import Trainer
+from training.trainer.callbacks import CallbackList
+from training.trainer.train import Trainer
import wandb
import yaml
+from text_recognizer.models import Model
+
EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments"
@@ -35,7 +37,7 @@ def get_level(experiment_config: Dict) -> int:
return 10
-def create_experiment_dir(model: Callable, experiment_config: Dict) -> Path:
+def create_experiment_dir(model: Type[Model], experiment_config: Dict) -> Path:
"""Create new experiment."""
EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True)
experiment_dir = EXPERIMENTS_DIRNAME / model.__name__
@@ -67,6 +69,8 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]
"""Loads all modules and arguments."""
# Import the data loader arguments.
data_loader_args = experiment_config.get("data_loader_args", {})
+ train_args = experiment_config.get("train_args", {})
+ data_loader_args["batch_size"] = train_args["batch_size"]
data_loader_args["dataset"] = experiment_config["dataset"]
data_loader_args["dataset_args"] = experiment_config.get("dataset_args", {})
@@ -94,7 +98,7 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]
optimizer_args = experiment_config.get("optimizer_args", {})
# Callbacks
- callback_modules = importlib.import_module("training.callbacks")
+ callback_modules = importlib.import_module("training.trainer.callbacks")
callbacks = [
getattr(callback_modules, callback)(
**check_args(experiment_config["callback_args"][callback])
@@ -208,6 +212,7 @@ def run_experiment(
with open(str(config_path), "w") as f:
yaml.dump(experiment_config, f)
+ # Train the model.
trainer = Trainer(
model=model,
model_dir=model_dir,
@@ -247,7 +252,7 @@ def run_experiment(
@click.option(
"--nowandb", is_flag=False, help="If true, do not use wandb for this run."
)
-def main(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None:
+def run_cli(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None:
"""Run experiment."""
if gpu < 0:
gpu_manager = GPUManager(True)
@@ -260,4 +265,4 @@ def main(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None:
if __name__ == "__main__":
- main()
+ run_cli()
diff --git a/src/training/trainer/__init__.py b/src/training/trainer/__init__.py
new file mode 100644
index 0000000..de41bfb
--- /dev/null
+++ b/src/training/trainer/__init__.py
@@ -0,0 +1,2 @@
+"""Trainer modules."""
+from .train import Trainer
diff --git a/src/training/callbacks/__init__.py b/src/training/trainer/callbacks/__init__.py
index fbcc285..5942276 100644
--- a/src/training/callbacks/__init__.py
+++ b/src/training/trainer/callbacks/__init__.py
@@ -2,6 +2,7 @@
from .base import Callback, CallbackList, Checkpoint
from .early_stopping import EarlyStopping
from .lr_schedulers import CyclicLR, MultiStepLR, OneCycleLR, ReduceLROnPlateau, StepLR
+from .progress_bar import ProgressBar
from .wandb_callbacks import WandbCallback, WandbImageLogger
__all__ = [
@@ -14,6 +15,7 @@ __all__ = [
"CyclicLR",
"MultiStepLR",
"OneCycleLR",
+ "ProgressBar",
"ReduceLROnPlateau",
"StepLR",
]
diff --git a/src/training/callbacks/base.py b/src/training/trainer/callbacks/base.py
index e0d91e6..8df94f3 100644
--- a/src/training/callbacks/base.py
+++ b/src/training/trainer/callbacks/base.py
@@ -1,7 +1,7 @@
"""Metaclass for callback functions."""
from enum import Enum
-from typing import Callable, Dict, List, Type, Union
+from typing import Callable, Dict, List, Optional, Type, Union
from loguru import logger
import numpy as np
@@ -36,27 +36,29 @@ class Callback:
"""Called when fit ends."""
pass
- def on_epoch_begin(self, epoch: int, logs: Dict = {}) -> None:
+ def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Called at the beginning of an epoch. Only used in training mode."""
pass
- def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None:
+ def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Called at the end of an epoch. Only used in training mode."""
pass
- def on_train_batch_begin(self, batch: int, logs: Dict = {}) -> None:
+ def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Called at the beginning of an epoch."""
pass
- def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None:
+ def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Called at the end of an epoch."""
pass
- def on_validation_batch_begin(self, batch: int, logs: Dict = {}) -> None:
+ def on_validation_batch_begin(
+ self, batch: int, logs: Optional[Dict] = None
+ ) -> None:
"""Called at the beginning of an epoch."""
pass
- def on_validation_batch_end(self, batch: int, logs: Dict = {}) -> None:
+ def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Called at the end of an epoch."""
pass
@@ -102,18 +104,18 @@ class CallbackList:
for callback in self._callbacks:
callback.on_fit_end()
- def on_epoch_begin(self, epoch: int, logs: Dict = {}) -> None:
+ def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Called at the beginning of an epoch."""
for callback in self._callbacks:
callback.on_epoch_begin(epoch, logs)
- def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None:
+ def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Called at the end of an epoch."""
for callback in self._callbacks:
callback.on_epoch_end(epoch, logs)
def _call_batch_hook(
- self, mode: str, hook: str, batch: int, logs: Dict = {}
+ self, mode: str, hook: str, batch: int, logs: Optional[Dict] = None
) -> None:
"""Helper function for all batch_{begin | end} methods."""
if hook == "begin":
@@ -123,39 +125,45 @@ class CallbackList:
else:
raise ValueError(f"Unrecognized hook {hook}.")
- def _call_batch_begin_hook(self, mode: str, batch: int, logs: Dict = {}) -> None:
+ def _call_batch_begin_hook(
+ self, mode: str, batch: int, logs: Optional[Dict] = None
+ ) -> None:
"""Helper function for all `on_*_batch_begin` methods."""
hook_name = f"on_{mode}_batch_begin"
self._call_batch_hook_helper(hook_name, batch, logs)
- def _call_batch_end_hook(self, mode: str, batch: int, logs: Dict = {}) -> None:
+ def _call_batch_end_hook(
+ self, mode: str, batch: int, logs: Optional[Dict] = None
+ ) -> None:
"""Helper function for all `on_*_batch_end` methods."""
hook_name = f"on_{mode}_batch_end"
self._call_batch_hook_helper(hook_name, batch, logs)
def _call_batch_hook_helper(
- self, hook_name: str, batch: int, logs: Dict = {}
+ self, hook_name: str, batch: int, logs: Optional[Dict] = None
) -> None:
"""Helper function for `on_*_batch_begin` methods."""
for callback in self._callbacks:
hook = getattr(callback, hook_name)
hook(batch, logs)
- def on_train_batch_begin(self, batch: int, logs: Dict = {}) -> None:
+ def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Called at the beginning of an epoch."""
- self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch)
+ self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch, logs)
- def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None:
+ def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Called at the end of an epoch."""
- self._call_batch_hook(self.mode_keys.TRAIN, "end", batch)
+ self._call_batch_hook(self.mode_keys.TRAIN, "end", batch, logs)
- def on_validation_batch_begin(self, batch: int, logs: Dict = {}) -> None:
+ def on_validation_batch_begin(
+ self, batch: int, logs: Optional[Dict] = None
+ ) -> None:
"""Called at the beginning of an epoch."""
- self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch)
+ self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch, logs)
- def on_validation_batch_end(self, batch: int, logs: Dict = {}) -> None:
+ def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Called at the end of an epoch."""
- self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch)
+ self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch, logs)
def __iter__(self) -> iter:
"""Iter function for callback list."""
diff --git a/src/training/callbacks/early_stopping.py b/src/training/trainer/callbacks/early_stopping.py
index c9b7907..02b431f 100644
--- a/src/training/callbacks/early_stopping.py
+++ b/src/training/trainer/callbacks/early_stopping.py
@@ -4,7 +4,8 @@ from typing import Dict, Union
from loguru import logger
import numpy as np
import torch
-from training.callbacks import Callback
+from torch import Tensor
+from training.trainer.callbacks import Callback
class EarlyStopping(Callback):
@@ -95,7 +96,7 @@ class EarlyStopping(Callback):
f"Stopped training at epoch {self.stopped_epoch + 1} with early stopping."
)
- def get_monitor_value(self, logs: Dict) -> Union[torch.Tensor, None]:
+ def get_monitor_value(self, logs: Dict) -> Union[Tensor, None]:
"""Extracts the monitor value."""
monitor_value = logs.get(self.monitor)
if monitor_value is None:
diff --git a/src/training/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py
index 00c7e9b..ba2226a 100644
--- a/src/training/callbacks/lr_schedulers.py
+++ b/src/training/trainer/callbacks/lr_schedulers.py
@@ -1,7 +1,7 @@
"""Callbacks for learning rate schedulers."""
from typing import Callable, Dict, List, Optional, Type
-from training.callbacks import Callback
+from training.trainer.callbacks import Callback
from text_recognizer.models import Model
@@ -19,7 +19,7 @@ class StepLR(Callback):
self.model = model
self.lr_scheduler = self.model.lr_scheduler
- def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None:
+ def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every epoch."""
self.lr_scheduler.step()
@@ -37,7 +37,7 @@ class MultiStepLR(Callback):
self.model = model
self.lr_scheduler = self.model.lr_scheduler
- def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None:
+ def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every epoch."""
self.lr_scheduler.step()
@@ -55,7 +55,7 @@ class ReduceLROnPlateau(Callback):
self.model = model
self.lr_scheduler = self.model.lr_scheduler
- def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None:
+ def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every epoch."""
val_loss = logs["val_loss"]
self.lr_scheduler.step(val_loss)
@@ -74,7 +74,7 @@ class CyclicLR(Callback):
self.model = model
self.lr_scheduler = self.model.lr_scheduler
- def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None:
+ def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every training batch."""
self.lr_scheduler.step()
@@ -92,6 +92,6 @@ class OneCycleLR(Callback):
self.model = model
self.lr_scheduler = self.model.lr_scheduler
- def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None:
+ def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every training batch."""
self.lr_scheduler.step()
diff --git a/src/training/trainer/callbacks/progress_bar.py b/src/training/trainer/callbacks/progress_bar.py
new file mode 100644
index 0000000..1970747
--- /dev/null
+++ b/src/training/trainer/callbacks/progress_bar.py
@@ -0,0 +1,61 @@
+"""Progress bar callback for the training loop."""
+from typing import Dict, Optional
+
+from tqdm import tqdm
+from training.trainer.callbacks import Callback
+
+
+class ProgressBar(Callback):
+ """A TQDM progress bar for the training loop."""
+
+ def __init__(self, epochs: int, log_batch_frequency: int = None) -> None:
+ """Initializes the tqdm callback."""
+ self.epochs = epochs
+ self.log_batch_frequency = log_batch_frequency
+ self.progress_bar = None
+ self.val_metrics = {}
+
+ def _configure_progress_bar(self) -> None:
+ """Configures the tqdm progress bar with custom bar format."""
+ self.progress_bar = tqdm(
+ total=len(self.model.data_loaders["train"]),
+ leave=True,
+ unit="step",
+ mininterval=self.log_batch_frequency,
+ bar_format="{desc} |{bar:30}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}",
+ )
+
+ def _key_abbreviations(self, logs: Dict) -> Dict:
+ """Changes the length of keys, so that the progress bar fits better."""
+
+ def rename(key: str) -> str:
+ """Renames accuracy to acc."""
+ return key.replace("accuracy", "acc")
+
+ return {rename(key): value for key, value in logs.items()}
+
+ def on_fit_begin(self) -> None:
+ """Creates a tqdm progress bar."""
+ self._configure_progress_bar()
+
+ def on_epoch_begin(self, epoch: int, logs: Optional[Dict]) -> None:
+ """Updates the description with the current epoch."""
+ self.progress_bar.reset()
+ self.progress_bar.set_description(f"Epoch {epoch}/{self.epochs}")
+
+ def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+ """At the end of each epoch, the validation metrics are updated to the progress bar."""
+ self.val_metrics = logs
+ self.progress_bar.set_postfix(**self._key_abbreviations(logs))
+ self.progress_bar.update()
+
+ def on_train_batch_end(self, batch: int, logs: Dict) -> None:
+ """Updates the progress bar for each training step."""
+ if self.val_metrics:
+ logs.update(self.val_metrics)
+ self.progress_bar.set_postfix(**self._key_abbreviations(logs))
+ self.progress_bar.update()
+
+ def on_fit_end(self) -> None:
+ """Closes the tqdm progress bar."""
+ self.progress_bar.close()
diff --git a/src/training/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py
index 6ada6df..e44c745 100644
--- a/src/training/callbacks/wandb_callbacks.py
+++ b/src/training/trainer/callbacks/wandb_callbacks.py
@@ -1,9 +1,9 @@
-"""Callbacks using wandb."""
+"""Callback for W&B."""
from typing import Callable, Dict, List, Optional, Type
import numpy as np
from torchvision.transforms import Compose, ToTensor
-from training.callbacks import Callback
+from training.trainer.callbacks import Callback
import wandb
from text_recognizer.datasets import Transpose
@@ -28,12 +28,12 @@ class WandbCallback(Callback):
if self.log_batch_frequency and batch % self.log_batch_frequency == 0:
wandb.log(logs, commit=True)
- def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None:
+ def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Logs training metrics."""
if logs is not None:
self._on_batch_end(batch, logs)
- def on_validation_batch_end(self, batch: int, logs: Dict = {}) -> None:
+ def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Logs validation metrics."""
if logs is not None:
self._on_batch_end(batch, logs)
diff --git a/src/training/population_based_training/__init__.py b/src/training/trainer/population_based_training/__init__.py
index 868d739..868d739 100644
--- a/src/training/population_based_training/__init__.py
+++ b/src/training/trainer/population_based_training/__init__.py
diff --git a/src/training/population_based_training/population_based_training.py b/src/training/trainer/population_based_training/population_based_training.py
index 868d739..868d739 100644
--- a/src/training/population_based_training/population_based_training.py
+++ b/src/training/trainer/population_based_training/population_based_training.py
diff --git a/src/training/train.py b/src/training/trainer/train.py
index aaa0430..a75ae8f 100644
--- a/src/training/train.py
+++ b/src/training/trainer/train.py
@@ -7,9 +7,9 @@ from typing import Dict, List, Optional, Tuple, Type
from loguru import logger
import numpy as np
import torch
-from tqdm import tqdm, trange
-from training.callbacks import Callback, CallbackList
-from training.util import RunningAverage
+from torch import Tensor
+from training.trainer.callbacks import Callback, CallbackList
+from training.trainer.util import RunningAverage
import wandb
from text_recognizer.models import Model
@@ -46,11 +46,11 @@ class Trainer:
self.model_dir = model_dir
self.checkpoint_path = checkpoint_path
self.start_epoch = 1
- self.epochs = train_args["epochs"] + self.start_epoch
+ self.epochs = train_args["epochs"]
self.callbacks = callbacks
if self.checkpoint_path is not None:
- self.start_epoch = self.model.load_checkpoint(self.checkpoint_path) + 1
+ self.start_epoch = self.model.load_checkpoint(self.checkpoint_path)
# Parse the name of the experiment.
experiment_dir = str(self.model_dir.parents[1]).split("/")
@@ -59,7 +59,7 @@ class Trainer:
def training_step(
self,
batch: int,
- samples: Tuple[torch.Tensor, torch.Tensor],
+ samples: Tuple[Tensor, Tensor],
loss_avg: Type[RunningAverage],
) -> Dict:
"""Performs the training step."""
@@ -108,27 +108,16 @@ class Trainer:
data_loader = self.model.data_loaders["train"]
- with tqdm(
- total=len(data_loader),
- leave=False,
- unit="step",
- bar_format="{n_fmt}/{total_fmt} |{bar:30}| {remaining} {rate_inv_fmt}{postfix}",
- ) as t:
- for batch, samples in enumerate(data_loader):
- self.callbacks.on_train_batch_begin(batch)
-
- metrics = self.training_step(batch, samples, loss_avg)
-
- self.callbacks.on_train_batch_end(batch, logs=metrics)
-
- # Update Tqdm progress bar.
- t.set_postfix(**metrics)
- t.update()
+ for batch, samples in enumerate(data_loader):
+ self.callbacks.on_train_batch_begin(batch)
+ metrics = self.training_step(batch, samples, loss_avg)
+ self.callbacks.on_train_batch_end(batch, logs=metrics)
+ @torch.no_grad()
def validation_step(
self,
batch: int,
- samples: Tuple[torch.Tensor, torch.Tensor],
+ samples: Tuple[Tensor, Tensor],
loss_avg: Type[RunningAverage],
) -> Dict:
"""Performs the validation step."""
@@ -158,6 +147,12 @@ class Trainer:
return metrics
+ def _log_val_metric(self, metrics_mean: Dict, epoch: Optional[int] = None) -> None:
+ log_str = "Validation metrics " + (f"at epoch {epoch} - " if epoch else " - ")
+ logger.debug(
+ log_str + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items())
+ )
+
def validate(self, epoch: Optional[int] = None) -> Dict:
"""Runs the validation loop for one epoch."""
# Set model to eval mode.
@@ -172,41 +167,18 @@ class Trainer:
# Summary for the current eval loop.
summary = []
- with tqdm(
- total=len(data_loader),
- leave=False,
- unit="step",
- bar_format="{n_fmt}/{total_fmt} |{bar:30}| {remaining} {rate_inv_fmt}{postfix}",
- ) as t:
- with torch.no_grad():
- for batch, samples in enumerate(data_loader):
- self.callbacks.on_validation_batch_begin(batch)
-
- metrics = self.validation_step(batch, samples, loss_avg)
-
- self.callbacks.on_validation_batch_end(batch, logs=metrics)
-
- summary.append(metrics)
-
- # Update Tqdm progress bar.
- t.set_postfix(**metrics)
- t.update()
+ for batch, samples in enumerate(data_loader):
+ self.callbacks.on_validation_batch_begin(batch)
+ metrics = self.validation_step(batch, samples, loss_avg)
+ self.callbacks.on_validation_batch_end(batch, logs=metrics)
+ summary.append(metrics)
# Compute mean of all metrics.
metrics_mean = {
"val_" + metric: np.mean([x[metric] for x in summary])
for metric in summary[0]
}
- if epoch:
- logger.debug(
- f"Validation metrics at epoch {epoch} - "
- + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items())
- )
- else:
- logger.debug(
- "Validation metrics - "
- + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items())
- )
+ self._log_val_metric(metrics_mean, epoch)
return metrics_mean
@@ -214,19 +186,14 @@ class Trainer:
"""Runs the training and evaluation loop."""
logger.debug(f"Running an experiment called {self.experiment_name}.")
+
+ # Set start time.
t_start = time.time()
self.callbacks.on_fit_begin()
- # TODO: fix progress bar as callback.
# Run the training loop.
- for epoch in trange(
- self.start_epoch,
- self.epochs,
- leave=False,
- bar_format="{desc}: {n_fmt}/{total_fmt} |{bar:30}| {remaining}{postfix}",
- desc="Epoch",
- ):
+ for epoch in range(self.start_epoch, self.epochs + 1):
self.callbacks.on_epoch_begin(epoch)
# Perform one training pass over the training set.
diff --git a/src/training/util.py b/src/training/trainer/util.py
index 132b2dc..132b2dc 100644
--- a/src/training/util.py
+++ b/src/training/trainer/util.py