From 905eeeb4c3c0ba54b5414eb8f435e2e9870b7307 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Wed, 24 Feb 2021 22:00:29 +0100 Subject: updates --- .flake8 | 1 + .pre-commit-config.yaml | 4 +- .pytype/imports/default.pyi | 2 +- .vim/coc-settings.json | 6 + noxfile.py | 13 +- poetry.lock | 1123 ++++++++-------- src/notebooks/00-testing-stuff-out.ipynb | 1361 +++----------------- src/notebooks/07-try-gtn.ipynb | 49 +- src/notebooks/Untitled.ipynb | 385 ++++++ src/notebooks/intersection.pdf | Bin 0 -> 10154 bytes src/tasks/build_transitions.py | 6 +- src/tasks/make_wordpieces.py | 2 +- src/text_recognizer/datasets/iam_preprocessor.py | 2 +- src/text_recognizer/datasets/transforms.py | 119 ++ src/text_recognizer/networks/__init__.py | 3 +- src/text_recognizer/networks/cnn_transformer.py | 4 +- .../networks/transducer/__init__.py | 1 + .../networks/transducer/tds_conv.py | 15 +- src/text_recognizer/networks/transducer/test.py | 60 + .../networks/transducer/transducer.py | 410 ++++++ 20 files changed, 1821 insertions(+), 1745 deletions(-) create mode 100644 .vim/coc-settings.json create mode 100644 src/notebooks/Untitled.ipynb create mode 100644 src/notebooks/intersection.pdf create mode 100644 src/text_recognizer/networks/transducer/test.py create mode 100644 src/text_recognizer/networks/transducer/transducer.py diff --git a/.flake8 b/.flake8 index eff48a6..b00f63b 100644 --- a/.flake8 +++ b/.flake8 @@ -7,3 +7,4 @@ application-import-names = text_recognizer,tests import-order-style = google docstring-convention = google per-file-ignores = tests/*:S101,tests/*:S106,src/text_recognizer/datasets/*:S110,src/training/callbacks/*:B006,src/tasks/build_transitions.py:C901 +exclude = src/text_recognizer/networks/transducer/* diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ea3565b..8fed8e5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,11 +9,11 @@ repos: hooks: - id: black name: black - entry: poetry run black + entry: black 
language: system types: [python] - id: flake8 name: flake8 - entry: poetry run flake8 + entry: flake8 language: system types: [python] diff --git a/.pytype/imports/default.pyi b/.pytype/imports/default.pyi index 7bb11b4..024e962 100644 --- a/.pytype/imports/default.pyi +++ b/.pytype/imports/default.pyi @@ -1,3 +1,3 @@ - from typing import Any + def __getattr__(name) -> Any: ... diff --git a/.vim/coc-settings.json b/.vim/coc-settings.json new file mode 100644 index 0000000..ce08b20 --- /dev/null +++ b/.vim/coc-settings.json @@ -0,0 +1,6 @@ +{ + "python.linting.pylintEnabled": false, + "python.linting.flake8Enabled": true, + "python.linting.enabled": true, + "python.formatting.provider": "black" +} diff --git a/noxfile.py b/noxfile.py index d1a8d1b..60c3923 100644 --- a/noxfile.py +++ b/noxfile.py @@ -33,6 +33,7 @@ def install_with_constraints(session: Session, *args: str, **kwargs: Any) -> Non "export", "--dev", "--format=requirements.txt", + "--without-hashes", f"--output={requirements.name}", external=True, ) @@ -47,7 +48,7 @@ def black(session: Session) -> None: session.run("black", *args) -@nox.session(python=["3.8", "3.7"]) +@nox.session(python=["3.8"]) def lint(session: Session) -> None: """Lint using flake8.""" args = session.posargs or locations @@ -82,7 +83,7 @@ def safety(session: Session) -> None: session.run("safety", "check", f"--file={requirements.name}", "--full-report") -@nox.session(python=["3.8", "3.7"]) +@nox.session(python=["3.8"]) def mypy(session: Session) -> None: """Type-check using mypy.""" args = session.posargs or locations @@ -90,7 +91,7 @@ def mypy(session: Session) -> None: session.run("mypy", *args) -@nox.session(python="3.7") +@nox.session(python="3.8") def pytype(session: Session) -> None: """Type-check using pytype.""" args = session.posargs or ["--disable=import-error", *locations] @@ -98,7 +99,7 @@ def pytype(session: Session) -> None: session.run("pytype", *args) -@nox.session(python=["3.8", "3.7"]) +@nox.session(python=["3.8"]) 
def tests(session: Session) -> None: """Run the test suite.""" args = session.posargs or ["--cov", "-m", "not e2e"] @@ -109,7 +110,7 @@ def tests(session: Session) -> None: session.run("pytest", *args) -@nox.session(python=["3.8", "3.7"]) +@nox.session(python=["3.8"]) def typeguard(session: Session) -> None: """Runtime type checking using Typeguard.""" args = session.posargs or ["-m", "not e2e"] @@ -118,7 +119,7 @@ def typeguard(session: Session) -> None: session.run("pytest", f"--typeguard-packages={package}", *args) -@nox.session(python=["3.8", "3.7"]) +@nox.session(python=["3.8"]) def xdoctest(session: Session) -> None: """Run examples with xdoctest.""" args = session.posargs or ["all"] diff --git a/poetry.lock b/poetry.lock index 7f715d8..72da168 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,117 +1,115 @@ [[package]] -category = "main" -description = "A configurable sidebar-enabled Sphinx theme" name = "alabaster" +version = "0.7.12" +description = "A configurable sidebar-enabled Sphinx theme" +category = "main" optional = false python-versions = "*" -version = "0.7.12" [[package]] -category = "dev" -description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." name = "appdirs" +version = "1.4.4" +description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." +category = "dev" optional = false python-versions = "*" -version = "1.4.4" [[package]] -category = "main" -description = "Disable App Nap on OS X 10.9" -marker = "sys_platform == \"darwin\" or platform_system == \"Darwin\" or python_version >= \"3.3\" and sys_platform == \"darwin\"" name = "appnope" +version = "0.1.0" +description = "Disable App Nap on OS X 10.9" +category = "main" optional = false python-versions = "*" -version = "0.1.0" [[package]] -category = "main" -description = "The secure Argon2 password hashing algorithm." 
name = "argon2-cffi" +version = "20.1.0" +description = "The secure Argon2 password hashing algorithm." +category = "main" optional = false python-versions = "*" -version = "20.1.0" [package.dependencies] cffi = ">=1.0.0" six = "*" [package.extras] -dev = ["coverage (>=5.0.2)", "hypothesis", "pytest", "sphinx", "wheel", "pre-commit"] +dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pytest", "sphinx", "wheel", "pre-commit"] docs = ["sphinx"] -tests = ["coverage (>=5.0.2)", "hypothesis", "pytest"] +tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pytest"] [[package]] -category = "main" -description = "Async generators and context managers for Python 3.5+" name = "async-generator" +version = "1.10" +description = "Async generators and context managers for Python 3.5+" +category = "main" optional = false python-versions = ">=3.5" -version = "1.10" [[package]] -category = "main" -description = "Atomic file writes." -marker = "sys_platform == \"win32\"" name = "atomicwrites" +version = "1.4.0" +description = "Atomic file writes." 
+category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -version = "1.4.0" [[package]] -category = "main" -description = "Classes Without Boilerplate" name = "attrs" +version = "20.3.0" +description = "Classes Without Boilerplate" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -version = "20.3.0" [package.extras] -dev = ["coverage (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface", "furo", "sphinx", "pre-commit"] +dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface", "furo", "sphinx", "pre-commit"] docs = ["furo", "sphinx", "zope.interface"] -tests = ["coverage (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface"] -tests_no_zope = ["coverage (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six"] +tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface"] +tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six"] [[package]] -category = "main" -description = "Internationalization utilities" name = "babel" +version = "2.9.0" +description = "Internationalization utilities" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -version = "2.9.0" [package.dependencies] pytz = ">=2015.7" [[package]] -category = "main" -description = "Specifications for callback functions passed in to an API" name = "backcall" +version = "0.2.0" +description = "Specifications for callback functions passed in to an API" +category = "main" optional = false python-versions = "*" -version = "0.2.0" [[package]] -category = "dev" -description = "Security oriented static analyser for python code." name = "bandit" +version = "1.6.2" +description = "Security oriented static analyser for python code." 
+category = "dev" optional = false python-versions = "*" -version = "1.6.2" [package.dependencies] +colorama = {version = ">=0.3.9", markers = "platform_system == \"Windows\""} GitPython = ">=1.0.1" PyYAML = ">=3.13" -colorama = ">=0.3.9" six = ">=1.10.0" stevedore = ">=1.20.0" [[package]] -category = "dev" -description = "The uncompromising code formatter." name = "black" +version = "19.10b0" +description = "The uncompromising code formatter." +category = "dev" optional = false python-versions = ">=3.6" -version = "19.10b0" [package.dependencies] appdirs = "*" @@ -126,12 +124,12 @@ typed-ast = ">=1.4.0" d = ["aiohttp (>=3.3.2)", "aiohttp-cors"] [[package]] -category = "main" -description = "An easy safelist-based HTML-sanitizing tool." name = "bleach" +version = "3.2.1" +description = "An easy safelist-based HTML-sanitizing tool." +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" -version = "3.2.1" [package.dependencies] packaging = "*" @@ -139,146 +137,143 @@ six = ">=1.9.0" webencodings = "*" [[package]] -category = "dev" -description = "A thin, practical wrapper around terminal coloring, styling, and positioning" name = "blessings" +version = "1.7" +description = "A thin, practical wrapper around terminal coloring, styling, and positioning" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -version = "1.7" [package.dependencies] six = "*" [[package]] -category = "main" -description = "When they're not builtins, they're boltons." name = "boltons" +version = "20.2.1" +description = "When they're not builtins, they're boltons." +category = "main" optional = false python-versions = "*" -version = "20.2.1" [[package]] -category = "main" -description = "Python package for providing Mozilla's CA Bundle." name = "certifi" +version = "2020.11.8" +description = "Python package for providing Mozilla's CA Bundle." 
+category = "main" optional = false python-versions = "*" -version = "2020.11.8" [[package]] -category = "main" -description = "Foreign Function Interface for Python calling C code." name = "cffi" +version = "1.14.3" +description = "Foreign Function Interface for Python calling C code." +category = "main" optional = false python-versions = "*" -version = "1.14.3" [package.dependencies] pycparser = "*" [[package]] -category = "main" -description = "Universal encoding detector for Python 2 and 3" name = "chardet" +version = "3.0.4" +description = "Universal encoding detector for Python 2 and 3" +category = "main" optional = false python-versions = "*" -version = "3.0.4" [[package]] -category = "main" -description = "Composable command line interface toolkit" name = "click" +version = "7.1.2" +description = "Composable command line interface toolkit" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" -version = "7.1.2" [[package]] -category = "main" -description = "Cross-platform colored terminal text." -marker = "sys_platform == \"win32\" or platform_system == \"Windows\" or python_version >= \"3.3\" and sys_platform == \"win32\"" name = "colorama" +version = "0.4.4" +description = "Cross-platform colored terminal text." +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" -version = "0.4.4" [[package]] -category = "main" -description = "Updated configparser from Python 3.8 for Python 2.6+." name = "configparser" +version = "5.0.1" +description = "Updated configparser from Python 3.8 for Python 2.6+." 
+category = "main" optional = false python-versions = ">=3.6" -version = "5.0.1" [package.extras] docs = ["sphinx", "jaraco.packaging (>=3.2)", "rst.linker (>=1.9)"] -testing = ["pytest (>=3.5,<3.7.3 || >3.7.3)", "pytest-checkdocs (>=1.2.3)", "pytest-flake8", "pytest-cov", "jaraco.test (>=3.2.0)", "pytest-black (>=0.3.7)", "pytest-mypy"] +testing = ["pytest (>=3.5,!=3.7.3)", "pytest-checkdocs (>=1.2.3)", "pytest-flake8", "pytest-cov", "jaraco.test (>=3.2.0)", "pytest-black (>=0.3.7)", "pytest-mypy"] [[package]] -category = "dev" -description = "Code coverage measurement for Python" name = "coverage" +version = "5.3" +description = "Code coverage measurement for Python" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4" -version = "5.3" [package.dependencies] -[package.dependencies.toml] -optional = true -version = "*" +toml = {version = "*", optional = true, markers = "extra == \"toml\""} [package.extras] toml = ["toml"] [[package]] -category = "main" -description = "Composable style cycles" name = "cycler" +version = "0.10.0" +description = "Composable style cycles" +category = "main" optional = false python-versions = "*" -version = "0.10.0" [package.dependencies] six = "*" [[package]] -category = "main" -description = "A utility for ensuring Google-style docstrings stay up to date with the source code." name = "darglint" +version = "1.5.6" +description = "A utility for ensuring Google-style docstrings stay up to date with the source code." 
+category = "main" optional = false python-versions = ">=3.6,<4.0" -version = "1.5.6" [[package]] -category = "main" -description = "A backport of the dataclasses module for Python 3.6" name = "dataclasses" +version = "0.6" +description = "A backport of the dataclasses module for Python 3.6" +category = "main" optional = false python-versions = "*" -version = "0.6" [[package]] -category = "main" -description = "Decorators for Humans" name = "decorator" +version = "4.4.2" +description = "Decorators for Humans" +category = "main" optional = false python-versions = ">=2.6, !=3.0.*, !=3.1.*" -version = "4.4.2" [[package]] -category = "main" -description = "XML bomb protection for Python stdlib modules" name = "defusedxml" +version = "0.6.0" +description = "XML bomb protection for Python stdlib modules" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" -version = "0.6.0" [[package]] -category = "main" -description = "Deserialize to objects while staying DRY" name = "desert" +version = "2020.11.18" +description = "Deserialize to objects while staying DRY" +category = "main" optional = false python-versions = ">=3.6" -version = "2020.11.18" [package.dependencies] attrs = "*" @@ -290,31 +285,31 @@ dev = ["coverage", "cuvner", "marshmallow-enum", "marshmallow-union", "pytest", test = ["coverage", "cuvner", "marshmallow-enum", "marshmallow-union", "pytest", "pytest-cov", "pytest-sphinx", "pytest-travis-fold", "tox", "importlib-metadata"] [[package]] -category = "main" -description = "Python bindings for the docker credentials store API" name = "docker-pycreds" +version = "0.4.0" +description = "Python bindings for the docker credentials store API" +category = "main" optional = false python-versions = "*" -version = "0.4.0" [package.dependencies] six = ">=1.4.0" [[package]] -category = "main" -description = "Docutils -- Python Documentation Utilities" name = "docutils" +version = "0.16" +description = "Docutils -- 
Python Documentation Utilities" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" -version = "0.16" [[package]] -category = "dev" -description = "A parser for Python dependency files" name = "dparse" +version = "0.5.1" +description = "A parser for Python dependency files" +category = "dev" optional = false python-versions = ">=3.5" -version = "0.5.1" [package.dependencies] packaging = "*" @@ -325,28 +320,28 @@ toml = "*" pipenv = ["pipenv"] [[package]] -category = "main" -description = "A new flavour of deep learning operations" name = "einops" +version = "0.3.0" +description = "A new flavour of deep learning operations" +category = "main" optional = false python-versions = "*" -version = "0.3.0" [[package]] -category = "main" -description = "Discover and load entry points from installed packages." name = "entrypoints" +version = "0.3" +description = "Discover and load entry points from installed packages." +category = "main" optional = false python-versions = ">=2.7" -version = "0.3" [[package]] -category = "main" -description = "the modular source code checker: pep8 pyflakes and co" name = "flake8" +version = "3.8.4" +description = "the modular source code checker: pep8 pyflakes and co" +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" -version = "3.8.4" [package.dependencies] mccabe = ">=0.6.0,<0.7.0" @@ -354,23 +349,23 @@ pycodestyle = ">=2.6.0a1,<2.7.0" pyflakes = ">=2.2.0,<2.3.0" [[package]] -category = "main" -description = "Flake8 Type Annotation Checks" name = "flake8-annotations" +version = "2.4.1" +description = "Flake8 Type Annotation Checks" +category = "main" optional = false python-versions = ">=3.6.1,<4.0.0" -version = "2.4.1" [package.dependencies] flake8 = ">=3.7,<3.9" [[package]] -category = "dev" -description = "Automated security testing with bandit and flake8." 
name = "flake8-bandit" +version = "2.1.2" +description = "Automated security testing with bandit and flake8." +category = "dev" optional = false python-versions = "*" -version = "2.1.2" [package.dependencies] bandit = "*" @@ -379,101 +374,100 @@ flake8-polyfill = "*" pycodestyle = "*" [[package]] -category = "dev" -description = "flake8 plugin to call black as a code style validator" name = "flake8-black" +version = "0.2.1" +description = "flake8 plugin to call black as a code style validator" +category = "dev" optional = false python-versions = "*" -version = "0.2.1" [package.dependencies] black = "*" flake8 = ">=3.0.0" [[package]] -category = "dev" -description = "A plugin for flake8 finding likely bugs and design problems in your program. Contains warnings that don't belong in pyflakes and pycodestyle." name = "flake8-bugbear" +version = "20.1.4" +description = "A plugin for flake8 finding likely bugs and design problems in your program. Contains warnings that don't belong in pyflakes and pycodestyle." +category = "dev" optional = false python-versions = ">=3.6" -version = "20.1.4" [package.dependencies] attrs = ">=19.2.0" flake8 = ">=3.0.0" [[package]] -category = "main" -description = "Extension for flake8 which uses pydocstyle to check docstrings" name = "flake8-docstrings" +version = "1.5.0" +description = "Extension for flake8 which uses pydocstyle to check docstrings" +category = "main" optional = false python-versions = "*" -version = "1.5.0" [package.dependencies] flake8 = ">=3" pydocstyle = ">=2.1" [[package]] -category = "dev" -description = "Flake8 and pylama plugin that checks the ordering of import statements." name = "flake8-import-order" +version = "0.18.1" +description = "Flake8 and pylama plugin that checks the ordering of import statements." 
+category = "dev" optional = false python-versions = "*" -version = "0.18.1" [package.dependencies] pycodestyle = "*" -setuptools = "*" [[package]] -category = "dev" -description = "Polyfill package for Flake8 plugins" name = "flake8-polyfill" +version = "1.0.2" +description = "Polyfill package for Flake8 plugins" +category = "dev" optional = false python-versions = "*" -version = "1.0.2" [package.dependencies] flake8 = "*" [[package]] -category = "main" -description = "Clean single-source support for Python 3 and 2" name = "future" +version = "0.18.2" +description = "Clean single-source support for Python 3 and 2" +category = "main" optional = false python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" -version = "0.18.2" [[package]] -category = "main" -description = "Git Object Database" name = "gitdb" +version = "4.0.5" +description = "Git Object Database" +category = "main" optional = false python-versions = ">=3.4" -version = "4.0.5" [package.dependencies] smmap = ">=3.0.1,<4" [[package]] -category = "main" -description = "Python Git Library" name = "gitpython" -optional = false -python-versions = ">=3.4" version = "3.1.11" +description = "Python Git Library" +category = "main" +optional = false +python-versions = ">=3.4" [package.dependencies] gitdb = ">=4.0.1,<5" [[package]] -category = "dev" -description = "An utility to monitor NVIDIA GPU status and usage" name = "gpustat" +version = "0.6.0" +description = "An utility to monitor NVIDIA GPU status and usage" +category = "dev" optional = false python-versions = "*" -version = "0.6.0" [package.dependencies] blessings = ">=1.6" @@ -485,12 +479,12 @@ six = ">=1.7" test = ["mock (>=2.0.0)", "pytest (<5.0)"] [[package]] -category = "dev" -description = "Simple Python interface for Graphviz" name = "graphviz" +version = "0.16" +description = "Simple Python interface for Graphviz" +category = "dev" optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*" -version = "0.16" 
[package.extras] dev = ["tox (>=3)", "flake8", "pep8-naming", "wheel", "twine"] @@ -498,51 +492,51 @@ docs = ["sphinx (>=1.8)", "sphinx-rtd-theme"] test = ["mock (>=3)", "pytest (>=4)", "pytest-mock (>=2)", "pytest-cov"] [[package]] -category = "main" -description = "Automatic differentiation with WFSTs" name = "gtn" +version = "0.0.0" +description = "Automatic differentiation with WFSTs" +category = "main" optional = false python-versions = ">=3.5" -version = "0.0.0" [[package]] -category = "main" -description = "Read and write HDF5 files from Python" name = "h5py" +version = "2.10.0" +description = "Read and write HDF5 files from Python" +category = "main" optional = false python-versions = "*" -version = "2.10.0" [package.dependencies] numpy = ">=1.7" six = "*" [[package]] -category = "main" -description = "Internationalized Domain Names in Applications (IDNA)" name = "idna" +version = "2.10" +description = "Internationalized Domain Names in Applications (IDNA)" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -version = "2.10" [[package]] -category = "main" -description = "Getting image size from png/jpeg/jpeg2000/gif file" name = "imagesize" +version = "1.2.0" +description = "Getting image size from png/jpeg/jpeg2000/gif file" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -version = "1.2.0" [[package]] -category = "main" -description = "IPython Kernel for Jupyter" name = "ipykernel" +version = "5.3.4" +description = "IPython Kernel for Jupyter" +category = "main" optional = false python-versions = ">=3.5" -version = "5.3.4" [package.dependencies] -appnope = "*" +appnope = {version = "*", markers = "platform_system == \"Darwin\""} ipython = ">=5.0.0" jupyter-client = "*" tornado = ">=4.2" @@ -552,24 +546,23 @@ traitlets = ">=4.1.0" test = ["pytest (!=5.3.4)", "pytest-cov", "flaky", "nose"] [[package]] -category = "main" -description = "IPython: Productive 
Interactive Computing" name = "ipython" +version = "7.19.0" +description = "IPython: Productive Interactive Computing" +category = "main" optional = false python-versions = ">=3.7" -version = "7.19.0" [package.dependencies] -appnope = "*" +appnope = {version = "*", markers = "sys_platform == \"darwin\""} backcall = "*" -colorama = "*" +colorama = {version = "*", markers = "sys_platform == \"win32\""} decorator = "*" jedi = ">=0.10" -pexpect = ">4.3" +pexpect = {version = ">4.3", markers = "sys_platform != \"win32\""} pickleshare = "*" prompt-toolkit = ">=2.0.0,<3.0.0 || >3.0.0,<3.0.1 || >3.0.1,<3.1.0" pygments = "*" -setuptools = ">=18.5" traitlets = ">=4.2" [package.extras] @@ -584,56 +577,53 @@ qtconsole = ["qtconsole"] test = ["nose (>=0.10.1)", "requests", "testpath", "pygments", "nbformat", "ipykernel", "numpy (>=1.14)"] [[package]] -category = "main" -description = "Vestigial utilities from IPython" name = "ipython-genutils" +version = "0.2.0" +description = "Vestigial utilities from IPython" +category = "main" optional = false python-versions = "*" -version = "0.2.0" [[package]] -category = "dev" -description = "IPython HTML widgets for Jupyter" name = "ipywidgets" +version = "7.5.1" +description = "IPython HTML widgets for Jupyter" +category = "dev" optional = false python-versions = "*" -version = "7.5.1" [package.dependencies] ipykernel = ">=4.5.1" +ipython = {version = ">=4.0.0", markers = "python_version >= \"3.3\""} nbformat = ">=4.2.0" traitlets = ">=4.3.1" widgetsnbextension = ">=3.5.0,<3.6.0" -[package.dependencies.ipython] -python = ">=3.3" -version = ">=4.0.0" - [package.extras] test = ["pytest (>=3.6.0)", "pytest-cov", "mock"] [[package]] -category = "main" -description = "An autocompletion tool for Python that can be used for text editors." name = "jedi" +version = "0.17.2" +description = "An autocompletion tool for Python that can be used for text editors." 
+category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" -version = "0.17.2" [package.dependencies] parso = ">=0.7.0,<0.8.0" [package.extras] -qa = ["flake8 (3.7.9)"] +qa = ["flake8 (==3.7.9)"] testing = ["Django (<3.1)", "colorama", "docopt", "pytest (>=3.9.0,<5.0.0)"] [[package]] -category = "main" -description = "A very fast and expressive template engine." name = "jinja2" +version = "2.11.2" +description = "A very fast and expressive template engine." +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" -version = "2.11.2" [package.dependencies] MarkupSafe = ">=0.23" @@ -642,25 +632,24 @@ MarkupSafe = ">=0.23" i18n = ["Babel (>=0.8)"] [[package]] -category = "main" -description = "Lightweight pipelining: using Python functions as pipeline jobs." name = "joblib" +version = "0.17.0" +description = "Lightweight pipelining: using Python functions as pipeline jobs." +category = "main" optional = false python-versions = ">=3.6" -version = "0.17.0" [[package]] -category = "main" -description = "An implementation of JSON Schema validation for Python" name = "jsonschema" +version = "3.2.0" +description = "An implementation of JSON Schema validation for Python" +category = "main" optional = false python-versions = "*" -version = "3.2.0" [package.dependencies] attrs = ">=17.4.0" pyrsistent = ">=0.14.0" -setuptools = "*" six = ">=1.11.0" [package.extras] @@ -668,12 +657,12 @@ format = ["idna", "jsonpointer (>1.13)", "rfc3987", "strict-rfc3339", "webcolors format_nongpl = ["idna", "jsonpointer (>1.13)", "webcolors", "rfc3986-validator (>0.1.0)", "rfc3339-validator"] [[package]] -category = "dev" -description = "Jupyter metapackage. Install all the Jupyter components in one go." name = "jupyter" +version = "1.0.0" +description = "Jupyter metapackage. Install all the Jupyter components in one go." 
+category = "dev" optional = false python-versions = "*" -version = "1.0.0" [package.dependencies] ipykernel = "*" @@ -684,12 +673,12 @@ notebook = "*" qtconsole = "*" [[package]] -category = "main" -description = "Jupyter protocol implementation and client libraries" name = "jupyter-client" +version = "6.1.7" +description = "Jupyter protocol implementation and client libraries" +category = "main" optional = false python-versions = ">=3.5" -version = "6.1.7" [package.dependencies] jupyter-core = ">=4.6.0" @@ -702,12 +691,12 @@ traitlets = "*" test = ["ipykernel", "ipython", "mock", "pytest", "pytest-asyncio", "async-generator", "pytest-timeout"] [[package]] -category = "dev" -description = "Jupyter terminal console" name = "jupyter-console" +version = "6.2.0" +description = "Jupyter terminal console" +category = "dev" optional = false python-versions = ">=3.6" -version = "6.2.0" [package.dependencies] ipykernel = "*" @@ -720,35 +709,35 @@ pygments = "*" test = ["pexpect"] [[package]] -category = "main" -description = "Jupyter core package. A base package on which Jupyter projects rely." name = "jupyter-core" +version = "4.7.0" +description = "Jupyter core package. A base package on which Jupyter projects rely." 
+category = "main" optional = false python-versions = ">=3.6" -version = "4.7.0" [package.dependencies] -pywin32 = ">=1.0" +pywin32 = {version = ">=1.0", markers = "sys_platform == \"win32\""} traitlets = "*" [[package]] -category = "main" -description = "Pygments theme using JupyterLab CSS variables" name = "jupyterlab-pygments" +version = "0.1.2" +description = "Pygments theme using JupyterLab CSS variables" +category = "main" optional = false python-versions = "*" -version = "0.1.2" [package.dependencies] pygments = ">=2.4.1,<3" [[package]] -category = "main" -description = "Select and install a Jupyter notebook theme" name = "jupyterthemes" +version = "0.20.0" +description = "Select and install a Jupyter notebook theme" +category = "main" optional = false python-versions = "*" -version = "0.20.0" [package.dependencies] ipython = ">=5.4.1" @@ -758,69 +747,69 @@ matplotlib = ">=1.4.3" notebook = ">=5.6.0" [[package]] -category = "main" -description = "A fast implementation of the Cassowary constraint solver" name = "kiwisolver" +version = "1.3.1" +description = "A fast implementation of the Cassowary constraint solver" +category = "main" optional = false python-versions = ">=3.6" -version = "1.3.1" [[package]] -category = "main" -description = "Python LESS compiler" name = "lesscpy" +version = "0.14.0" +description = "Python LESS compiler" +category = "main" optional = false python-versions = "*" -version = "0.14.0" [package.dependencies] ply = "*" six = "*" [[package]] -category = "main" -description = "Python logging made (stupidly) simple" name = "loguru" +version = "0.5.3" +description = "Python logging made (stupidly) simple" +category = "main" optional = false python-versions = ">=3.5" -version = "0.5.3" [package.dependencies] -colorama = ">=0.3.4" -win32-setctime = ">=1.0.0" +colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""} +win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""} [package.extras] dev = ["codecov 
(>=2.0.15)", "colorama (>=0.3.4)", "flake8 (>=3.7.7)", "tox (>=3.9.0)", "tox-travis (>=0.12)", "pytest (>=4.6.2)", "pytest-cov (>=2.7.1)", "Sphinx (>=2.2.1)", "sphinx-autobuild (>=0.7.1)", "sphinx-rtd-theme (>=0.4.3)", "black (>=19.10b0)", "isort (>=5.1.1)"] [[package]] -category = "main" -description = "Safely add untrusted strings to HTML/XML markup." name = "markupsafe" +version = "1.1.1" +description = "Safely add untrusted strings to HTML/XML markup." +category = "main" optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*" -version = "1.1.1" [[package]] -category = "main" -description = "A lightweight library for converting complex datatypes to and from native Python datatypes." name = "marshmallow" +version = "3.9.1" +description = "A lightweight library for converting complex datatypes to and from native Python datatypes." +category = "main" optional = false python-versions = ">=3.5" -version = "3.9.1" [package.extras] -dev = ["pytest", "pytz", "simplejson", "mypy (0.790)", "flake8 (3.8.4)", "flake8-bugbear (20.1.4)", "pre-commit (>=2.4,<3.0)", "tox"] -docs = ["sphinx (3.3.0)", "sphinx-issues (1.2.0)", "alabaster (0.7.12)", "sphinx-version-warning (1.1.2)", "autodocsumm (0.2.1)"] -lint = ["mypy (0.790)", "flake8 (3.8.4)", "flake8-bugbear (20.1.4)", "pre-commit (>=2.4,<3.0)"] +dev = ["pytest", "pytz", "simplejson", "mypy (==0.790)", "flake8 (==3.8.4)", "flake8-bugbear (==20.1.4)", "pre-commit (>=2.4,<3.0)", "tox"] +docs = ["sphinx (==3.3.0)", "sphinx-issues (==1.2.0)", "alabaster (==0.7.12)", "sphinx-version-warning (==1.1.2)", "autodocsumm (==0.2.1)"] +lint = ["mypy (==0.790)", "flake8 (==3.8.4)", "flake8-bugbear (==20.1.4)", "pre-commit (>=2.4,<3.0)"] tests = ["pytest", "pytz", "simplejson"] [[package]] -category = "main" -description = "Python plotting package" name = "matplotlib" +version = "3.3.3" +description = "Python plotting package" +category = "main" optional = false python-versions = ">=3.6" -version = "3.3.3" 
[package.dependencies] cycler = ">=0.10" @@ -831,36 +820,36 @@ pyparsing = ">=2.0.3,<2.0.4 || >2.0.4,<2.1.2 || >2.1.2,<2.1.6 || >2.1.6" python-dateutil = ">=2.1" [[package]] -category = "main" -description = "McCabe checker, plugin for flake8" name = "mccabe" +version = "0.6.1" +description = "McCabe checker, plugin for flake8" +category = "main" optional = false python-versions = "*" -version = "0.6.1" [[package]] -category = "main" -description = "The fastest markdown parser in pure Python" name = "mistune" +version = "0.8.4" +description = "The fastest markdown parser in pure Python" +category = "main" optional = false python-versions = "*" -version = "0.8.4" [[package]] -category = "main" -description = "More routines for operating on iterables, beyond itertools" name = "more-itertools" +version = "8.6.0" +description = "More routines for operating on iterables, beyond itertools" +category = "main" optional = false python-versions = ">=3.5" -version = "8.6.0" [[package]] -category = "dev" -description = "Optional static typing for Python" name = "mypy" +version = "0.770" +description = "Optional static typing for Python" +category = "dev" optional = false python-versions = ">=3.5" -version = "0.770" [package.dependencies] mypy-extensions = ">=0.4.3,<0.5.0" @@ -871,20 +860,20 @@ typing-extensions = ">=3.7.4" dmypy = ["psutil (>=4.0)"] [[package]] -category = "main" -description = "Experimental type system extensions for programs checked with the mypy typechecker." name = "mypy-extensions" +version = "0.4.3" +description = "Experimental type system extensions for programs checked with the mypy typechecker." +category = "main" optional = false python-versions = "*" -version = "0.4.3" [[package]] -category = "main" -description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor." name = "nbclient" +version = "0.5.1" +description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor." 
+category = "main" optional = false python-versions = ">=3.6" -version = "0.5.1" [package.dependencies] async-generator = "*" @@ -899,12 +888,12 @@ sphinx = ["Sphinx (>=1.7)", "sphinx-book-theme", "mock", "moto", "myst-parser"] test = ["codecov", "coverage", "ipython", "ipykernel", "ipywidgets", "pytest (>=4.1)", "pytest-cov (>=2.6.1)", "check-manifest", "flake8", "mypy", "tox", "bumpversion", "xmltodict", "pip (>=18.1)", "wheel (>=0.31.0)", "setuptools (>=38.6.0)", "twine (>=1.11.0)", "black"] [[package]] -category = "main" -description = "Converting Jupyter Notebooks" name = "nbconvert" +version = "6.0.7" +description = "Converting Jupyter Notebooks" +category = "main" optional = false python-versions = ">=3.6" -version = "6.0.7" [package.dependencies] bleach = "*" @@ -922,19 +911,19 @@ testpath = "*" traitlets = ">=4.2" [package.extras] -all = ["pytest", "pytest-cov", "pytest-dependency", "ipykernel", "ipywidgets (>=7)", "pyppeteer (0.2.2)", "tornado (>=4.0)", "sphinx (>=1.5.1)", "sphinx-rtd-theme", "nbsphinx (>=0.2.12)", "ipython"] +all = ["pytest", "pytest-cov", "pytest-dependency", "ipykernel", "ipywidgets (>=7)", "pyppeteer (==0.2.2)", "tornado (>=4.0)", "sphinx (>=1.5.1)", "sphinx-rtd-theme", "nbsphinx (>=0.2.12)", "ipython"] docs = ["sphinx (>=1.5.1)", "sphinx-rtd-theme", "nbsphinx (>=0.2.12)", "ipython"] serve = ["tornado (>=4.0)"] -test = ["pytest", "pytest-cov", "pytest-dependency", "ipykernel", "ipywidgets (>=7)", "pyppeteer (0.2.2)"] -webpdf = ["pyppeteer (0.2.2)"] +test = ["pytest", "pytest-cov", "pytest-dependency", "ipykernel", "ipywidgets (>=7)", "pyppeteer (==0.2.2)"] +webpdf = ["pyppeteer (==0.2.2)"] [[package]] -category = "main" -description = "The Jupyter Notebook format" name = "nbformat" +version = "5.0.8" +description = "The Jupyter Notebook format" +category = "main" optional = false python-versions = ">=3.5" -version = "5.0.8" [package.dependencies] ipython-genutils = "*" @@ -947,20 +936,20 @@ fast = ["fastjsonschema"] test = 
["fastjsonschema", "testpath", "pytest", "pytest-cov"] [[package]] -category = "main" -description = "Patch asyncio to allow nested event loops" name = "nest-asyncio" +version = "1.4.3" +description = "Patch asyncio to allow nested event loops" +category = "main" optional = false python-versions = ">=3.5" -version = "1.4.3" [[package]] -category = "main" -description = "Natural Language Toolkit" name = "nltk" +version = "3.5" +description = "Natural Language Toolkit" +category = "main" optional = false python-versions = "*" -version = "3.5" [package.dependencies] click = "*" @@ -977,15 +966,14 @@ tgrep = ["pyparsing"] twitter = ["twython"] [[package]] -category = "main" -description = "A web-based notebook environment for interactive computing" name = "notebook" +version = "6.1.5" +description = "A web-based notebook environment for interactive computing" +category = "main" optional = false python-versions = ">=3.5" -version = "6.1.5" [package.dependencies] -Send2Trash = "*" argon2-cffi = "*" ipykernel = "*" ipython-genutils = "*" @@ -996,6 +984,7 @@ nbconvert = "*" nbformat = "*" prometheus-client = "*" pyzmq = ">=17" +Send2Trash = "*" terminado = ">=0.8.3" tornado = ">=5.0" traitlets = ">=4.2.1" @@ -1005,161 +994,160 @@ docs = ["sphinx", "nbsphinx", "sphinxcontrib-github-alt"] test = ["nose", "coverage", "requests", "nose-warnings-filters", "nbval", "nose-exclude", "selenium", "pytest", "pytest-cov", "requests-unixsocket"] [[package]] -category = "main" -description = "NumPy is the fundamental package for array computing with Python." name = "numpy" +version = "1.19.4" +description = "NumPy is the fundamental package for array computing with Python." 
+category = "main" optional = false python-versions = ">=3.6" -version = "1.19.4" [[package]] -category = "dev" -description = "Python Bindings for the NVIDIA Management Library" name = "nvidia-ml-py3" +version = "7.352.0" +description = "Python Bindings for the NVIDIA Management Library" +category = "dev" optional = false python-versions = "*" -version = "7.352.0" [[package]] -category = "main" -description = "A flexible configuration library" name = "omegaconf" +version = "2.0.5" +description = "A flexible configuration library" +category = "main" optional = false python-versions = ">=3.6" -version = "2.0.5" [package.dependencies] PyYAML = ">=5.1" typing-extensions = "*" [[package]] -category = "main" -description = "Wrapper package for OpenCV python bindings." name = "opencv-python" +version = "4.4.0.46" +description = "Wrapper package for OpenCV python bindings." +category = "main" optional = false python-versions = ">=3.6" -version = "4.4.0.46" [[package]] -category = "main" -description = "Core utilities for Python packages" name = "packaging" -optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" version = "20.4" +description = "Core utilities for Python packages" +category = "main" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" [package.dependencies] pyparsing = ">=2.0.2" six = "*" [[package]] -category = "main" -description = "Utilities for writing pandoc filters in python" name = "pandocfilters" +version = "1.4.3" +description = "Utilities for writing pandoc filters in python" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -version = "1.4.3" [[package]] -category = "main" -description = "A Python Parser" name = "parso" +version = "0.7.1" +description = "A Python Parser" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -version = "0.7.1" [package.extras] testing = ["docopt", "pytest (>=3.0.7)"] 
[[package]] -category = "dev" -description = "Utility library for gitignore style pattern matching of file paths." name = "pathspec" +version = "0.8.1" +description = "Utility library for gitignore style pattern matching of file paths." +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" -version = "0.8.1" [[package]] -category = "main" -description = "File system general utilities" name = "pathtools" +version = "0.1.2" +description = "File system general utilities" +category = "main" optional = false python-versions = "*" -version = "0.1.2" [[package]] -category = "dev" -description = "Python Build Reasonableness" name = "pbr" +version = "5.5.1" +description = "Python Build Reasonableness" +category = "dev" optional = false python-versions = ">=2.6" -version = "5.5.1" [[package]] -category = "main" -description = "Pexpect allows easy control of interactive console applications." -marker = "python_version >= \"3.3\" and sys_platform != \"win32\" or sys_platform != \"win32\"" name = "pexpect" +version = "4.8.0" +description = "Pexpect allows easy control of interactive console applications." 
+category = "main" optional = false python-versions = "*" -version = "4.8.0" [package.dependencies] ptyprocess = ">=0.5" [[package]] -category = "main" -description = "Tiny 'shelve'-like database with concurrency support" name = "pickleshare" +version = "0.7.5" +description = "Tiny 'shelve'-like database with concurrency support" +category = "main" optional = false python-versions = "*" -version = "0.7.5" [[package]] -category = "main" -description = "Python Imaging Library (Fork)" name = "pillow" +version = "8.0.1" +description = "Python Imaging Library (Fork)" +category = "main" optional = false python-versions = ">=3.6" -version = "8.0.1" [[package]] -category = "main" -description = "plugin and hook calling mechanisms for python" name = "pluggy" +version = "0.13.1" +description = "plugin and hook calling mechanisms for python" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -version = "0.13.1" [package.extras] dev = ["pre-commit", "tox"] [[package]] -category = "main" -description = "Python Lex & Yacc" name = "ply" +version = "3.11" +description = "Python Lex & Yacc" +category = "main" optional = false python-versions = "*" -version = "3.11" [[package]] -category = "main" -description = "Python client for the Prometheus monitoring system." name = "prometheus-client" +version = "0.9.0" +description = "Python client for the Prometheus monitoring system." 
+category = "main" optional = false python-versions = "*" -version = "0.9.0" [package.extras] twisted = ["twisted"] [[package]] -category = "main" -description = "Promises/A+ implementation for Python" name = "promise" +version = "2.3" +description = "Promises/A+ implementation for Python" +category = "main" optional = false python-versions = "*" -version = "2.3" [package.dependencies] six = "*" @@ -1168,126 +1156,125 @@ six = "*" test = ["pytest (>=2.7.3)", "pytest-cov", "coveralls", "futures", "pytest-benchmark", "mock"] [[package]] -category = "main" -description = "Library for building powerful interactive command lines in Python" name = "prompt-toolkit" +version = "3.0.8" +description = "Library for building powerful interactive command lines in Python" +category = "main" optional = false python-versions = ">=3.6.1" -version = "3.0.8" [package.dependencies] wcwidth = "*" [[package]] -category = "main" -description = "Protocol Buffers" name = "protobuf" +version = "3.14.0" +description = "Protocol Buffers" +category = "main" optional = false python-versions = "*" -version = "3.14.0" [package.dependencies] six = ">=1.9" [[package]] -category = "main" -description = "Cross-platform lib for process and system monitoring in Python." name = "psutil" +version = "5.7.3" +description = "Cross-platform lib for process and system monitoring in Python." 
+category = "main" optional = false python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -version = "5.7.3" [package.extras] test = ["ipaddress", "mock", "unittest2", "enum34", "pywin32", "wmi"] [[package]] -category = "main" -description = "Run a subprocess in a pseudo terminal" -marker = "python_version >= \"3.3\" and sys_platform != \"win32\" or sys_platform != \"win32\" or os_name != \"nt\" or python_version >= \"3.3\" and sys_platform != \"win32\" and (python_version >= \"3.3\" and sys_platform != \"win32\" or sys_platform != \"win32\")" name = "ptyprocess" +version = "0.6.0" +description = "Run a subprocess in a pseudo terminal" +category = "main" optional = false python-versions = "*" -version = "0.6.0" [[package]] -category = "main" -description = "library with cross-python path, ini-parsing, io, code, log facilities" name = "py" +version = "1.9.0" +description = "library with cross-python path, ini-parsing, io, code, log facilities" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -version = "1.9.0" [[package]] -category = "main" -description = "Python style guide checker" name = "pycodestyle" +version = "2.6.0" +description = "Python style guide checker" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -version = "2.6.0" [[package]] -category = "main" -description = "C parser in Python" name = "pycparser" +version = "2.20" +description = "C parser in Python" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -version = "2.20" [[package]] -category = "main" -description = "Python docstring style checker" name = "pydocstyle" +version = "5.1.1" +description = "Python docstring style checker" +category = "main" optional = false python-versions = ">=3.5" -version = "5.1.1" [package.dependencies] snowballstemmer = "*" [[package]] -category = "main" -description = "passive checker of Python programs" name = 
"pyflakes" +version = "2.2.0" +description = "passive checker of Python programs" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -version = "2.2.0" [[package]] -category = "main" -description = "Pygments is a syntax highlighting package written in Python." name = "pygments" +version = "2.7.2" +description = "Pygments is a syntax highlighting package written in Python." +category = "main" optional = false python-versions = ">=3.5" -version = "2.7.2" [[package]] -category = "main" -description = "Python parsing module" name = "pyparsing" +version = "2.4.7" +description = "Python parsing module" +category = "main" optional = false python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" -version = "2.4.7" [[package]] -category = "main" -description = "Persistent/Functional/Immutable data structures" name = "pyrsistent" +version = "0.17.3" +description = "Persistent/Functional/Immutable data structures" +category = "main" optional = false python-versions = ">=3.5" -version = "0.17.3" [[package]] -category = "main" -description = "pytest: simple powerful testing with Python" name = "pytest" +version = "5.4.3" +description = "pytest: simple powerful testing with Python" +category = "main" optional = false python-versions = ">=3.5" -version = "5.4.3" [package.dependencies] -atomicwrites = ">=1.0" +atomicwrites = {version = ">=1.0", markers = "sys_platform == \"win32\""} attrs = ">=17.4.0" -colorama = "*" +colorama = {version = "*", markers = "sys_platform == \"win32\""} more-itertools = ">=4.0.0" packaging = "*" pluggy = ">=0.12,<1.0" @@ -1295,31 +1282,31 @@ py = ">=1.5.0" wcwidth = "*" [package.extras] -checkqa-mypy = ["mypy (v0.761)"] +checkqa-mypy = ["mypy (==v0.761)"] testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"] [[package]] -category = "dev" -description = "Pytest plugin for measuring coverage." 
name = "pytest-cov" +version = "2.10.1" +description = "Pytest plugin for measuring coverage." +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" -version = "2.10.1" [package.dependencies] coverage = ">=4.4" pytest = ">=4.6" [package.extras] -testing = ["fields", "hunter", "process-tests (2.0.2)", "six", "pytest-xdist", "virtualenv"] +testing = ["fields", "hunter", "process-tests (==2.0.2)", "six", "pytest-xdist", "virtualenv"] [[package]] -category = "dev" -description = "Thin-wrapper around the mock package for easier use with pytest" name = "pytest-mock" +version = "3.3.1" +description = "Thin-wrapper around the mock package for easier use with pytest" +category = "dev" optional = false python-versions = ">=3.5" -version = "3.3.1" [package.dependencies] pytest = ">=5.0" @@ -1328,34 +1315,31 @@ pytest = ">=5.0" dev = ["pre-commit", "tox", "pytest-asyncio"] [[package]] -category = "main" -description = "Extensions to the standard Python datetime module" name = "python-dateutil" +version = "2.8.1" +description = "Extensions to the standard Python datetime module" +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" -version = "2.8.1" [package.dependencies] six = ">=1.5" [[package]] -category = "main" -description = "Python extension for computing string edit distances and similarities." name = "python-levenshtein" +version = "0.12.0" +description = "Python extension for computing string edit distances and similarities." +category = "main" optional = false python-versions = "*" -version = "0.12.0" - -[package.dependencies] -setuptools = "*" [[package]] -category = "main" -description = "The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch." name = "pytorch-metric-learning" +version = "0.9.94" +description = "The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. 
Written in PyTorch." +category = "main" optional = false python-versions = ">=3.0" -version = "0.9.94" [package.dependencies] numpy = "*" @@ -1370,58 +1354,56 @@ with-hooks = ["record-keeper (>=0.9.29)", "faiss-gpu (>=1.6.3)", "tensorboard"] with-hooks-cpu = ["record-keeper (>=0.9.29)", "faiss-cpu (>=1.6.3)", "tensorboard"] [[package]] -category = "main" -description = "World timezone definitions, modern and historical" name = "pytz" +version = "2020.4" +description = "World timezone definitions, modern and historical" +category = "main" optional = false python-versions = "*" -version = "2020.4" [[package]] -category = "main" -description = "Python for Window Extensions" -marker = "sys_platform == \"win32\"" name = "pywin32" +version = "300" +description = "Python for Window Extensions" +category = "main" optional = false python-versions = "*" -version = "300" [[package]] -category = "main" -description = "Python bindings for the winpty library" -marker = "os_name == \"nt\"" name = "pywinpty" +version = "0.5.7" +description = "Python bindings for the winpty library" +category = "main" optional = false python-versions = "*" -version = "0.5.7" [[package]] -category = "main" -description = "YAML parser and emitter for Python" name = "pyyaml" +version = "5.3.1" +description = "YAML parser and emitter for Python" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" -version = "5.3.1" [[package]] -category = "main" -description = "Python bindings for 0MQ" name = "pyzmq" +version = "20.0.0" +description = "Python bindings for 0MQ" +category = "main" optional = false python-versions = ">=3.5" -version = "20.0.0" [package.dependencies] -cffi = "*" -py = "*" +cffi = {version = "*", markers = "implementation_name === \"pypy\""} +py = {version = "*", markers = "implementation_name === \"pypy\""} [[package]] -category = "dev" -description = "Jupyter Qt console" name = "qtconsole" +version = "4.7.7" +description = "Jupyter Qt 
console" +category = "dev" optional = false python-versions = "*" -version = "4.7.7" [package.dependencies] ipykernel = ">=4.1" @@ -1438,50 +1420,50 @@ doc = ["Sphinx (>=1.3)"] test = ["pytest", "mock"] [[package]] -category = "dev" -description = "Provides an abstraction layer on top of the various Qt bindings (PyQt5, PyQt4 and PySide) and additional custom QWidgets." name = "qtpy" +version = "1.9.0" +description = "Provides an abstraction layer on top of the various Qt bindings (PyQt5, PyQt4 and PySide) and additional custom QWidgets." +category = "dev" optional = false python-versions = "*" -version = "1.9.0" [[package]] -category = "dev" -description = "Python client for Redis key-value store" name = "redis" +version = "3.5.3" +description = "Python client for Redis key-value store" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" -version = "3.5.3" [package.extras] hiredis = ["hiredis (>=0.1.3)"] [[package]] -category = "dev" -description = "Redis locking mechanism" name = "redlock-py" +version = "1.0.8" +description = "Redis locking mechanism" +category = "dev" optional = false python-versions = "*" -version = "1.0.8" [package.dependencies] redis = "*" [[package]] -category = "main" -description = "Alternative regular expression module, to replace re." name = "regex" +version = "2020.11.13" +description = "Alternative regular expression module, to replace re." +category = "main" optional = false python-versions = "*" -version = "2020.11.13" [[package]] -category = "main" -description = "Python HTTP for Humans." name = "requests" +version = "2.25.0" +description = "Python HTTP for Humans." 
+category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" -version = "2.25.0" [package.dependencies] certifi = ">=2017.4.17" @@ -1491,30 +1473,29 @@ urllib3 = ">=1.21.1,<1.27" [package.extras] security = ["pyOpenSSL (>=0.14)", "cryptography (>=1.3.4)"] -socks = ["PySocks (>=1.5.6,<1.5.7 || >1.5.7)", "win-inet-pton"] +socks = ["PySocks (>=1.5.6,!=1.5.7)", "win-inet-pton"] [[package]] -category = "dev" -description = "Checks installed dependencies for known vulnerabilities." name = "safety" +version = "1.9.0" +description = "Checks installed dependencies for known vulnerabilities." +category = "dev" optional = false python-versions = ">=3.5" -version = "1.9.0" [package.dependencies] Click = ">=6.0" dparse = ">=0.5.1" packaging = "*" requests = "*" -setuptools = "*" [[package]] -category = "main" -description = "A set of python modules for machine learning and data mining" name = "scikit-learn" +version = "0.23.2" +description = "A set of python modules for machine learning and data mining" +category = "main" optional = false python-versions = ">=3.6" -version = "0.23.2" [package.dependencies] joblib = ">=0.11" @@ -1526,39 +1507,39 @@ threadpoolctl = ">=2.0.0" alldeps = ["numpy (>=1.13.3)", "scipy (>=0.19.1)"] [[package]] -category = "main" -description = "SciPy: Scientific Library for Python" name = "scipy" +version = "1.5.4" +description = "SciPy: Scientific Library for Python" +category = "main" optional = false python-versions = ">=3.6" -version = "1.5.4" [package.dependencies] numpy = ">=1.14.5" [[package]] -category = "main" -description = "Send file to trash natively under Mac OS X, Windows and Linux." name = "send2trash" +version = "1.5.0" +description = "Send file to trash natively under Mac OS X, Windows and Linux." 
+category = "main" optional = false python-versions = "*" -version = "1.5.0" [[package]] -category = "main" -description = "SentencePiece python wrapper" name = "sentencepiece" +version = "0.1.95" +description = "SentencePiece python wrapper" +category = "main" optional = false python-versions = "*" -version = "0.1.95" [[package]] -category = "main" -description = "Python client for Sentry (https://sentry.io)" name = "sentry-sdk" +version = "0.19.4" +description = "Python client for Sentry (https://sentry.io)" +category = "main" optional = false python-versions = "*" -version = "0.19.4" [package.dependencies] certifi = "*" @@ -1581,56 +1562,55 @@ sqlalchemy = ["sqlalchemy (>=1.2)"] tornado = ["tornado (>=5)"] [[package]] -category = "main" -description = "A generator library for concise, unambiguous and URL-safe UUIDs." name = "shortuuid" +version = "1.0.1" +description = "A generator library for concise, unambiguous and URL-safe UUIDs." +category = "main" optional = false python-versions = ">=3.5" -version = "1.0.1" [[package]] -category = "main" -description = "Python 2 and 3 compatibility utilities" name = "six" +version = "1.15.0" +description = "Python 2 and 3 compatibility utilities" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" -version = "1.15.0" [[package]] -category = "main" -description = "A pure Python implementation of a sliding window memory map manager" name = "smmap" +version = "3.0.4" +description = "A pure Python implementation of a sliding window memory map manager" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -version = "3.0.4" [[package]] -category = "main" -description = "This package provides 26 stemmers for 25 languages generated from Snowball algorithms." name = "snowballstemmer" +version = "2.0.0" +description = "This package provides 26 stemmers for 25 languages generated from Snowball algorithms." 
+category = "main" optional = false python-versions = "*" -version = "2.0.0" [[package]] -category = "main" -description = "Python documentation generator" name = "sphinx" +version = "3.3.1" +description = "Python documentation generator" +category = "main" optional = false python-versions = ">=3.5" -version = "3.3.1" [package.dependencies] -Jinja2 = ">=2.3" -Pygments = ">=2.0" alabaster = ">=0.7,<0.8" babel = ">=1.3" -colorama = ">=0.3.5" +colorama = {version = ">=0.3.5", markers = "sys_platform == \"win32\""} docutils = ">=0.12" imagesize = "*" +Jinja2 = ">=2.3" packaging = "*" +Pygments = ">=2.0" requests = ">=2.5.0" -setuptools = "*" snowballstemmer = ">=1.1" sphinxcontrib-applehelp = "*" sphinxcontrib-devhelp = "*" @@ -1645,12 +1625,12 @@ lint = ["flake8 (>=3.5.0)", "flake8-import-order", "mypy (>=0.790)", "docutils-s test = ["pytest", "pytest-cov", "html5lib", "typed-ast", "cython"] [[package]] -category = "main" -description = "Type hints (PEP 484) support for the Sphinx autodoc extension" name = "sphinx-autodoc-typehints" +version = "1.11.1" +description = "Type hints (PEP 484) support for the Sphinx autodoc extension" +category = "main" optional = false python-versions = ">=3.5.2" -version = "1.11.1" [package.dependencies] Sphinx = ">=3.0" @@ -1660,153 +1640,153 @@ test = ["pytest (>=3.1.0)", "typing-extensions (>=3.5)", "sphobjinv (>=2.0)", "S type_comments = ["typed-ast (>=1.4.0)"] [[package]] -category = "main" -description = "Read the Docs theme for Sphinx" name = "sphinx-rtd-theme" +version = "0.4.3" +description = "Read the Docs theme for Sphinx" +category = "main" optional = false python-versions = "*" -version = "0.4.3" [package.dependencies] sphinx = "*" [[package]] -category = "main" -description = "sphinxcontrib-applehelp is a sphinx extension which outputs Apple help books" name = "sphinxcontrib-applehelp" +version = "1.0.2" +description = "sphinxcontrib-applehelp is a sphinx extension which outputs Apple help books" +category = "main" optional 
= false python-versions = ">=3.5" -version = "1.0.2" [package.extras] lint = ["flake8", "mypy", "docutils-stubs"] test = ["pytest"] [[package]] -category = "main" -description = "sphinxcontrib-devhelp is a sphinx extension which outputs Devhelp document." name = "sphinxcontrib-devhelp" +version = "1.0.2" +description = "sphinxcontrib-devhelp is a sphinx extension which outputs Devhelp document." +category = "main" optional = false python-versions = ">=3.5" -version = "1.0.2" [package.extras] lint = ["flake8", "mypy", "docutils-stubs"] test = ["pytest"] [[package]] -category = "main" -description = "sphinxcontrib-htmlhelp is a sphinx extension which renders HTML help files" name = "sphinxcontrib-htmlhelp" +version = "1.0.3" +description = "sphinxcontrib-htmlhelp is a sphinx extension which renders HTML help files" +category = "main" optional = false python-versions = ">=3.5" -version = "1.0.3" [package.extras] lint = ["flake8", "mypy", "docutils-stubs"] test = ["pytest", "html5lib"] [[package]] -category = "main" -description = "A sphinx extension which renders display math in HTML via JavaScript" name = "sphinxcontrib-jsmath" +version = "1.0.1" +description = "A sphinx extension which renders display math in HTML via JavaScript" +category = "main" optional = false python-versions = ">=3.5" -version = "1.0.1" [package.extras] test = ["pytest", "flake8", "mypy"] [[package]] -category = "main" -description = "sphinxcontrib-qthelp is a sphinx extension which outputs QtHelp document." name = "sphinxcontrib-qthelp" +version = "1.0.3" +description = "sphinxcontrib-qthelp is a sphinx extension which outputs QtHelp document." +category = "main" optional = false python-versions = ">=3.5" -version = "1.0.3" [package.extras] lint = ["flake8", "mypy", "docutils-stubs"] test = ["pytest"] [[package]] -category = "main" -description = "sphinxcontrib-serializinghtml is a sphinx extension which outputs \"serialized\" HTML files (json and pickle)." 
name = "sphinxcontrib-serializinghtml" +version = "1.1.4" +description = "sphinxcontrib-serializinghtml is a sphinx extension which outputs \"serialized\" HTML files (json and pickle)." +category = "main" optional = false python-versions = ">=3.5" -version = "1.1.4" [package.extras] lint = ["flake8", "mypy", "docutils-stubs"] test = ["pytest"] [[package]] -category = "dev" -description = "Manage dynamic plugins for Python applications" name = "stevedore" +version = "3.2.2" +description = "Manage dynamic plugins for Python applications" +category = "dev" optional = false python-versions = ">=3.6" -version = "3.2.2" [package.dependencies] pbr = ">=2.0.0,<2.1.0 || >2.1.0" [[package]] -category = "main" -description = "A backport of the subprocess module from Python 3 for use on 2.x." name = "subprocess32" +version = "3.5.4" +description = "A backport of the subprocess module from Python 3 for use on 2.x." +category = "main" optional = false python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*, <4" -version = "3.5.4" [[package]] -category = "main" -description = "Tornado websocket backend for the Xterm.js Javascript terminal emulator library." name = "terminado" +version = "0.9.1" +description = "Tornado websocket backend for the Xterm.js Javascript terminal emulator library." 
+category = "main" optional = false python-versions = ">=3.6" -version = "0.9.1" [package.dependencies] -ptyprocess = "*" -pywinpty = ">=0.5" +ptyprocess = {version = "*", markers = "os_name != \"nt\""} +pywinpty = {version = ">=0.5", markers = "os_name == \"nt\""} tornado = ">=4" [[package]] -category = "main" -description = "Test utilities for code working with files and commands" name = "testpath" +version = "0.4.4" +description = "Test utilities for code working with files and commands" +category = "main" optional = false python-versions = "*" -version = "0.4.4" [package.extras] test = ["pathlib2"] [[package]] -category = "main" -description = "threadpoolctl" name = "threadpoolctl" +version = "2.1.0" +description = "threadpoolctl" +category = "main" optional = false python-versions = ">=3.5" -version = "2.1.0" [[package]] -category = "main" -description = "Python Library for Tom's Obvious, Minimal Language" name = "toml" +version = "0.10.2" +description = "Python Library for Tom's Obvious, Minimal Language" +category = "main" optional = false python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" -version = "0.10.2" [[package]] -category = "main" -description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" name = "torch" +version = "1.7.0" +description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" +category = "main" optional = false python-versions = ">=3.6.1" -version = "1.7.0" [package.dependencies] dataclasses = "*" @@ -1815,20 +1795,20 @@ numpy = "*" typing-extensions = "*" [[package]] -category = "main" -description = "Model summary in PyTorch, based off of the original torchsummary." name = "torch-summary" +version = "1.4.3" +description = "Model summary in PyTorch, based off of the original torchsummary." 
+category = "main" optional = false python-versions = ">=3.5" -version = "1.4.3" [[package]] -category = "main" -description = "image and video datasets and models for torch deep learning" name = "torchvision" +version = "0.8.1" +description = "image and video datasets and models for torch deep learning" +category = "main" optional = false python-versions = "*" -version = "0.8.1" [package.dependencies] numpy = "*" @@ -1839,31 +1819,31 @@ torch = "1.7.0" scipy = ["scipy"] [[package]] -category = "main" -description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." name = "tornado" +version = "6.1" +description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." +category = "main" optional = false python-versions = ">= 3.5" -version = "6.1" [[package]] -category = "main" -description = "Fast, Extensible Progress Meter" name = "tqdm" +version = "4.53.0" +description = "Fast, Extensible Progress Meter" +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" -version = "4.53.0" [package.extras] dev = ["py-make (>=0.1.0)", "twine", "argopt", "pydoc-markdown", "wheel"] [[package]] -category = "main" -description = "Traitlets Python configuration system" name = "traitlets" +version = "5.0.5" +description = "Traitlets Python configuration system" +category = "main" optional = false python-versions = ">=3.7" -version = "5.0.5" [package.dependencies] ipython-genutils = "*" @@ -1872,76 +1852,76 @@ ipython-genutils = "*" test = ["pytest"] [[package]] -category = "dev" -description = "a fork of Python 2 and 3 ast modules with type comment support" name = "typed-ast" +version = "1.4.1" +description = "a fork of Python 2 and 3 ast modules with type comment support" +category = "dev" optional = false python-versions = "*" -version = "1.4.1" [[package]] -category = "dev" -description = "Run-time type checker for Python" name = 
"typeguard" +version = "2.10.0" +description = "Run-time type checker for Python" +category = "dev" optional = false python-versions = ">=3.5.3" -version = "2.10.0" [package.extras] doc = ["sphinx-rtd-theme", "sphinx-autodoc-typehints (>=1.2.0)"] test = ["pytest", "typing-extensions"] [[package]] -category = "main" -description = "Backported and Experimental Type Hints for Python 3.5+" name = "typing-extensions" +version = "3.7.4.3" +description = "Backported and Experimental Type Hints for Python 3.5+" +category = "main" optional = false python-versions = "*" -version = "3.7.4.3" [[package]] -category = "main" -description = "Runtime inspection utilities for typing module." name = "typing-inspect" +version = "0.6.0" +description = "Runtime inspection utilities for typing module." +category = "main" optional = false python-versions = "*" -version = "0.6.0" [package.dependencies] mypy-extensions = ">=0.3.0" typing-extensions = ">=3.7.4" [[package]] -category = "main" -description = "HTTP library with thread-safe connection pooling, file post, and more." name = "urllib3" +version = "1.26.2" +description = "HTTP library with thread-safe connection pooling, file post, and more." +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4" -version = "1.26.2" [package.extras] brotli = ["brotlipy (>=0.6.0)"] secure = ["pyOpenSSL (>=0.14)", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "certifi", "ipaddress"] -socks = ["PySocks (>=1.5.6,<1.5.7 || >1.5.7,<2.0)"] +socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] [[package]] -category = "main" -description = "A CLI and library for interacting with the Weights and Biases API." name = "wandb" +version = "0.10.12" +description = "A CLI and library for interacting with the Weights and Biases API." 
+category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -version = "0.10.12" [package.dependencies] Click = ">=7.0" -GitPython = ">=1.0.0" -PyYAML = "*" configparser = ">=3.8.1" docker-pycreds = ">=0.4.0" +GitPython = ">=1.0.0" promise = ">=2.0,<3" protobuf = ">=3.12.0" psutil = ">=5.0.0" python-dateutil = ">=2.6.1" +PyYAML = "*" requests = ">=2.0.0,<3" sentry-sdk = ">=0.4.0" shortuuid = ">=0.5.0" @@ -1952,16 +1932,16 @@ watchdog = ">=0.8.3" [package.extras] aws = ["boto3"] gcp = ["google-cloud-storage"] -grpc = ["grpcio (1.27.2)"] +grpc = ["grpcio (==1.27.2)"] kubeflow = ["kubernetes", "minio", "google-cloud-storage", "sh"] [[package]] -category = "main" -description = "Filesystem events monitoring" name = "watchdog" +version = "0.10.4" +description = "Filesystem events monitoring" +category = "main" optional = false python-versions = "*" -version = "0.10.4" [package.dependencies] pathtools = ">=0.1.1" @@ -1970,51 +1950,50 @@ pathtools = ">=0.1.1" watchmedo = ["PyYAML (>=3.10)", "argh (>=0.24.1)"] [[package]] -category = "main" -description = "Measures the displayed width of unicode strings in a terminal" name = "wcwidth" +version = "0.2.5" +description = "Measures the displayed width of unicode strings in a terminal" +category = "main" optional = false python-versions = "*" -version = "0.2.5" [[package]] -category = "main" -description = "Character encoding aliases for legacy web content" name = "webencodings" +version = "0.5.1" +description = "Character encoding aliases for legacy web content" +category = "main" optional = false python-versions = "*" -version = "0.5.1" [[package]] -category = "dev" -description = "IPython HTML widgets for Jupyter" name = "widgetsnbextension" +version = "3.5.1" +description = "IPython HTML widgets for Jupyter" +category = "dev" optional = false python-versions = "*" -version = "3.5.1" [package.dependencies] notebook = ">=4.4.1" [[package]] -category = "main" -description = "A small Python 
utility to set file creation time on Windows" -marker = "sys_platform == \"win32\"" name = "win32-setctime" +version = "1.0.3" +description = "A small Python utility to set file creation time on Windows" +category = "main" optional = false python-versions = ">=3.5" -version = "1.0.3" [package.extras] dev = ["pytest (>=4.6.2)", "black (>=19.3b0)"] [[package]] -category = "dev" -description = "A rewrite of the builtin doctest module" name = "xdoctest" +version = "0.12.0" +description = "A rewrite of the builtin doctest module" +category = "dev" optional = false python-versions = "*" -version = "0.12.0" [package.dependencies] six = "*" @@ -2025,9 +2004,9 @@ optional = ["pygments", "colorama"] tests = ["pytest", "pytest-cov", "codecov", "scikit-build", "cmake", "ninja", "pybind11"] [metadata] -content-hash = "1f194d7de179e9676ef1f8e51b83ff15c001627803008ef8225e8e14ab3acab0" -lock-version = "1.0" +lock-version = "1.1" python-versions = "^3.8" +content-hash = "1f194d7de179e9676ef1f8e51b83ff15c001627803008ef8225e8e14ab3acab0" [metadata.files] alabaster = [ diff --git a/src/notebooks/00-testing-stuff-out.ipynb b/src/notebooks/00-testing-stuff-out.ipynb index 0e4b298..2d6b43c 100644 --- a/src/notebooks/00-testing-stuff-out.ipynb +++ b/src/notebooks/00-testing-stuff-out.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -25,16 +25,71 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "from text_recognizer.networks import CNN" + "from text_recognizer.networks import CNN, TDS2d" ] }, { "cell_type": "code", - "execution_count": 63, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tds2d = TDS2d(**{\n", + " \"depth\" : 4,\n", + " \"tds_groups\" : [\n", + " { \"channels\" : 4, \"num_blocks\" : 3, \"stride\" : [2, 2] },\n", + " { \"channels\" : 32, \"num_blocks\" : 3, \"stride\" 
: [2, 2] },\n", + " { \"channels\" : 64, \"num_blocks\" : 3, \"stride\" : [2, 2] },\n", + " { \"channels\" : 128, \"num_blocks\" : 3, \"stride\" : [2, 1] },\n", + " ],\n", + " \"kernel_size\" : [5, 7],\n", + " \"dropout_rate\" : 0.1\n", + " }, input_dim=32, output_dim=128)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tds2d" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "summary(tds2d, (1, 28, 952), device=\"cpu\", depth=3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "t = torch.randn(2,1, 28, 952)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tds2d(t).shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -43,7 +98,7 @@ }, { "cell_type": "code", - "execution_count": 79, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -52,54 +107,25 @@ }, { "cell_type": "code", - "execution_count": 81, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Sequential(\n", - " (0): Sequential(\n", - " (0): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - " (1): Sequential(\n", - " (0): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - ")" - ] - }, - "execution_count": 81, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "nn.Sequential(i,i)" ] }, { "cell_type": "code", - "execution_count": 64, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([2, 128, 1, 59])" - ] - }, - "execution_count": 64, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "cnn(t).shape" ] }, { "cell_type": "code", - "execution_count": 2, + 
"execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -108,7 +134,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -117,7 +143,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -126,7 +152,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -135,160 +161,34 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "29.5" - ] - }, - "execution_count": 24, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "5 * 59 / 10" ] }, { "cell_type": "code", - "execution_count": 25, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([2, 1, 28, 952])" - ] - }, - "execution_count": 25, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "x.shape" ] }, { "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "===============================================================================================\n", - "Layer (type:depth-idx) Output Shape Param #\n", - "===============================================================================================\n", - "├─Encoder: 1-1 [-1, 32, 5, 59] --\n", - "| └─Sequential: 2-1 [-1, 32, 5, 59] --\n", - "| | └─Sequential: 3-1 [-1, 32, 14, 476] 544\n", - "| | └─Sequential: 3-2 [-1, 128, 7, 238] 65,664\n", - "| | └─Sequential: 3-3 [-1, 128, 6, 119] 262,272\n", - "| | └─Sequential: 3-4 [-1, 256, 5, 59] 524,544\n", - "| | └─_ResidualBlock: 3-5 [-1, 256, 5, 59] 655,360\n", - "| | └─_ResidualBlock: 3-6 [-1, 256, 5, 59] 655,360\n", - "| | └─Conv2d: 3-7 [-1, 32, 5, 59] 8,224\n", - "| 
└─VectorQuantizer: 2-2 [-1, 32, 5, 59] --\n", - "├─Decoder: 1-2 [-1, 1, 28, 952] --\n", - "| └─Sequential: 2-3 [-1, 1, 28, 952] --\n", - "| └─Sequential: 2-4 [-1, 256, 5, 59] --\n", - "| └─Sequential: 2 [] --\n", - "| | └─Sequential: 3-8 [-1, 256, 5, 59] (recursive)\n", - "| └─Sequential: 2 [] --\n", - "| | └─Conv2d: 3-9 [-1, 256, 5, 59] 8,448\n", - "| | └─_ResidualBlock: 3-10 [-1, 256, 5, 59] 655,360\n", - "| | └─_ResidualBlock: 3-11 [-1, 256, 5, 59] 655,360\n", - "| └─Sequential: 2-5 [-1, 1, 28, 952] --\n", - "| └─Sequential: 2 [] --\n", - "| | └─Sequential: 3-12 [-1, 1, 28, 952] (recursive)\n", - "| └─Sequential: 2 [] --\n", - "| | └─Sequential: 3-13 [-1, 128, 6, 118] 524,416\n", - "| | └─Upsample: 3-14 [-1, 128, 6, 119] --\n", - "| | └─Sequential: 3-15 [-1, 128, 7, 238] 262,272\n", - "| | └─Upsample: 3-16 [-1, 128, 7, 238] --\n", - "| | └─Sequential: 3-17 [-1, 32, 14, 476] 65,568\n", - "| | └─ConvTranspose2d: 3-18 [-1, 1, 28, 952] 513\n", - "| | └─Tanh: 3-19 [-1, 1, 28, 952] --\n", - "===============================================================================================\n", - "Total params: 4,343,905\n", - "Trainable params: 4,343,905\n", - "Non-trainable params: 0\n", - "Total mult-adds (G): 1.76\n", - "===============================================================================================\n", - "Input size (MB): 0.10\n", - "Forward/backward pass size (MB): 9.32\n", - "Params size (MB): 16.57\n", - "Estimated Total Size (MB): 26.00\n", - "===============================================================================================\n" - ] - }, - { - "data": { - "text/plain": [ - "===============================================================================================\n", - "Layer (type:depth-idx) Output Shape Param #\n", - "===============================================================================================\n", - "├─Encoder: 1-1 [-1, 32, 5, 59] --\n", - "| └─Sequential: 2-1 [-1, 32, 5, 59] --\n", - "| | └─Sequential: 3-1 
[-1, 32, 14, 476] 544\n", - "| | └─Sequential: 3-2 [-1, 128, 7, 238] 65,664\n", - "| | └─Sequential: 3-3 [-1, 128, 6, 119] 262,272\n", - "| | └─Sequential: 3-4 [-1, 256, 5, 59] 524,544\n", - "| | └─_ResidualBlock: 3-5 [-1, 256, 5, 59] 655,360\n", - "| | └─_ResidualBlock: 3-6 [-1, 256, 5, 59] 655,360\n", - "| | └─Conv2d: 3-7 [-1, 32, 5, 59] 8,224\n", - "| └─VectorQuantizer: 2-2 [-1, 32, 5, 59] --\n", - "├─Decoder: 1-2 [-1, 1, 28, 952] --\n", - "| └─Sequential: 2-3 [-1, 1, 28, 952] --\n", - "| └─Sequential: 2-4 [-1, 256, 5, 59] --\n", - "| └─Sequential: 2 [] --\n", - "| | └─Sequential: 3-8 [-1, 256, 5, 59] (recursive)\n", - "| └─Sequential: 2 [] --\n", - "| | └─Conv2d: 3-9 [-1, 256, 5, 59] 8,448\n", - "| | └─_ResidualBlock: 3-10 [-1, 256, 5, 59] 655,360\n", - "| | └─_ResidualBlock: 3-11 [-1, 256, 5, 59] 655,360\n", - "| └─Sequential: 2-5 [-1, 1, 28, 952] --\n", - "| └─Sequential: 2 [] --\n", - "| | └─Sequential: 3-12 [-1, 1, 28, 952] (recursive)\n", - "| └─Sequential: 2 [] --\n", - "| | └─Sequential: 3-13 [-1, 128, 6, 118] 524,416\n", - "| | └─Upsample: 3-14 [-1, 128, 6, 119] --\n", - "| | └─Sequential: 3-15 [-1, 128, 7, 238] 262,272\n", - "| | └─Upsample: 3-16 [-1, 128, 7, 238] --\n", - "| | └─Sequential: 3-17 [-1, 32, 14, 476] 65,568\n", - "| | └─ConvTranspose2d: 3-18 [-1, 1, 28, 952] 513\n", - "| | └─Tanh: 3-19 [-1, 1, 28, 952] --\n", - "===============================================================================================\n", - "Total params: 4,343,905\n", - "Trainable params: 4,343,905\n", - "Non-trainable params: 0\n", - "Total mult-adds (G): 1.76\n", - "===============================================================================================\n", - "Input size (MB): 0.10\n", - "Forward/backward pass size (MB): 9.32\n", - "Params size (MB): 16.57\n", - "Estimated Total Size (MB): 26.00\n", - "===============================================================================================" - ] - }, - "execution_count": 26, - "metadata": {}, - 
"output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "summary(vqvae, (1, 28, 952), device=\"cpu\", depth=3)" ] }, { "cell_type": "code", - "execution_count": 94, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -297,47 +197,25 @@ }, { "cell_type": "code", - "execution_count": 107, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([2, 32, 4, 59])" - ] - }, - "execution_count": 107, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "up(tt).shape" ] }, { "cell_type": "code", - "execution_count": 104, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([2, 32, 1, 59])" - ] - }, - "execution_count": 104, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "tt.shape" ] }, { "cell_type": "code", - "execution_count": 29, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -353,7 +231,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -362,27 +240,16 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([1, 30, 2048])" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "e(t).shape" ] }, { "cell_type": "code", - "execution_count": 34, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -391,7 +258,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -401,7 +268,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -410,7 +277,7 @@ }, { "cell_type": "code", - 
"execution_count": 46, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -419,75 +286,34 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "ename": "ModuleAttributeError", - "evalue": "'Embedding' object has no attribute 'device'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mModuleAttributeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0memb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m~/.cache/pypoetry/virtualenvs/text-recognizer-N1c_zsdp-py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 776\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 777\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 778\u001b[0;31m raise ModuleAttributeError(\"'{}' object has no attribute '{}'\".format(\n\u001b[0m\u001b[1;32m 779\u001b[0m type(self).__name__, name))\n\u001b[1;32m 780\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mModuleAttributeError\u001b[0m: 'Embedding' object has no attribute 'device'" - ] - } - ], + "outputs": [], "source": [ "emb.device" ] }, { "cell_type": "code", - "execution_count": 49, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[-1.0624, 0.0674, 0.9387, ..., -0.1852, -0.1303, 0.8005],\n", - " [-1.0624, 0.0674, 0.9387, 
..., -0.1852, -0.1303, 0.8005],\n", - " [-1.0624, 0.0674, 0.9387, ..., -0.1852, -0.1303, 0.8005],\n", - " ...,\n", - " [-1.0624, 0.0674, 0.9387, ..., -0.1852, -0.1303, 0.8005],\n", - " [-1.0624, 0.0674, 0.9387, ..., -0.1852, -0.1303, 0.8005],\n", - " [-1.0624, 0.0674, 0.9387, ..., -0.1852, -0.1303, 0.8005]])" - ] - }, - "execution_count": 49, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "ee" ] }, { "cell_type": "code", - "execution_count": 47, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([16, 256])" - ] - }, - "execution_count": 47, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "ee.shape" ] }, { "cell_type": "code", - "execution_count": 50, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -496,27 +322,16 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([16, 10, 256])" - ] - }, - "execution_count": 52, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "t.shape" ] }, { "cell_type": "code", - "execution_count": 56, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -525,47 +340,25 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([16, 12, 256])" - ] - }, - "execution_count": 57, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "t.shape" ] }, { "cell_type": "code", - "execution_count": 42, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([1, 256])" - ] - }, - "execution_count": 42, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "e.shape" ] }, { "cell_type": "code", - 
"execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -574,7 +367,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -583,7 +376,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -603,7 +396,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -612,7 +405,7 @@ }, { "cell_type": "code", - "execution_count": 66, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -621,90 +414,16 @@ }, { "cell_type": "code", - "execution_count": 52, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "==========================================================================================\n", - "Layer (type:depth-idx) Output Shape Param #\n", - "==========================================================================================\n", - "├─Sequential: 1-1 [-1, 64, 12, 474] --\n", - "| └─Conv2d: 2-1 [-1, 64, 12, 474] 3,136\n", - "| └─BatchNorm2d: 2-2 [-1, 64, 12, 474] 128\n", - "| └─ReLU: 2-3 [-1, 64, 12, 474] --\n", - "├─Sequential: 1-2 [-1, 68, 1, 30] --\n", - "| └─ResidualLayer: 2-4 [-1, 64, 12, 474] --\n", - "| | └─Sequential: 3-1 [-1, 64, 12, 474] 147,968\n", - "| └─ResidualLayer: 2-5 [-1, 65, 6, 237] --\n", - "| | └─Sequential: 3-2 [-1, 65, 6, 237] 156,325\n", - "| └─ResidualLayer: 2-6 [-1, 66, 3, 119] --\n", - "| | └─Sequential: 3-3 [-1, 66, 3, 119] 161,172\n", - "| └─ResidualLayer: 2-7 [-1, 67, 2, 60] --\n", - "| | └─Sequential: 3-4 [-1, 67, 2, 60] 166,093\n", - "| └─ResidualLayer: 2-8 [-1, 68, 1, 30] --\n", - "| | └─Sequential: 3-5 [-1, 68, 1, 30] 171,088\n", - "==========================================================================================\n", - "Total params: 805,910\n", - "Trainable params: 805,910\n", - "Non-trainable 
params: 0\n", - "Total mult-adds (M): 21.05\n", - "==========================================================================================\n", - "Input size (MB): 0.10\n", - "Forward/backward pass size (MB): 5.55\n", - "Params size (MB): 3.07\n", - "Estimated Total Size (MB): 8.73\n", - "==========================================================================================\n" - ] - }, - { - "data": { - "text/plain": [ - "==========================================================================================\n", - "Layer (type:depth-idx) Output Shape Param #\n", - "==========================================================================================\n", - "├─Sequential: 1-1 [-1, 64, 12, 474] --\n", - "| └─Conv2d: 2-1 [-1, 64, 12, 474] 3,136\n", - "| └─BatchNorm2d: 2-2 [-1, 64, 12, 474] 128\n", - "| └─ReLU: 2-3 [-1, 64, 12, 474] --\n", - "├─Sequential: 1-2 [-1, 68, 1, 30] --\n", - "| └─ResidualLayer: 2-4 [-1, 64, 12, 474] --\n", - "| | └─Sequential: 3-1 [-1, 64, 12, 474] 147,968\n", - "| └─ResidualLayer: 2-5 [-1, 65, 6, 237] --\n", - "| | └─Sequential: 3-2 [-1, 65, 6, 237] 156,325\n", - "| └─ResidualLayer: 2-6 [-1, 66, 3, 119] --\n", - "| | └─Sequential: 3-3 [-1, 66, 3, 119] 161,172\n", - "| └─ResidualLayer: 2-7 [-1, 67, 2, 60] --\n", - "| | └─Sequential: 3-4 [-1, 67, 2, 60] 166,093\n", - "| └─ResidualLayer: 2-8 [-1, 68, 1, 30] --\n", - "| | └─Sequential: 3-5 [-1, 68, 1, 30] 171,088\n", - "==========================================================================================\n", - "Total params: 805,910\n", - "Trainable params: 805,910\n", - "Non-trainable params: 0\n", - "Total mult-adds (M): 21.05\n", - "==========================================================================================\n", - "Input size (MB): 0.10\n", - "Forward/backward pass size (MB): 5.55\n", - "Params size (MB): 3.07\n", - "Estimated Total Size (MB): 8.73\n", - "==========================================================================================" - ] - }, - 
"execution_count": 52, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "summary(backbone, (1, 28, 952), device=\"cpu\", depth=3)" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -715,187 +434,25 @@ }, { "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Sequential(\n", - " (0): SELU(inplace=True)\n", - " (1): Sequential(\n", - " (0): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(1, 1), bias=False)\n", - " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): SELU(inplace=True)\n", - " (3): MaxPool2d(kernel_size=(2, 4), stride=2, padding=1, dilation=1, ceil_mode=False)\n", - " )\n", - " (2): Sequential(\n", - " (0): Sequential(\n", - " (0): WideBlock(\n", - " (activation): SELU(inplace=True)\n", - " (blocks): Sequential(\n", - " (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (1): SELU(inplace=True)\n", - " (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (3): Dropout(p=0.1, inplace=False)\n", - " (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): SELU(inplace=True)\n", - " (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " )\n", - " )\n", - " (1): Sequential(\n", - " (0): WideBlock(\n", - " (activation): SELU(inplace=True)\n", - " (blocks): Sequential(\n", - " (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (1): SELU(inplace=True)\n", - " (2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (3): Dropout(p=0.1, inplace=False)\n", - " (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - 
" (5): SELU(inplace=True)\n", - " (6): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", - " )\n", - " (shortcut): Sequential(\n", - " (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", - " )\n", - " )\n", - " )\n", - " (2): Sequential(\n", - " (0): WideBlock(\n", - " (activation): SELU(inplace=True)\n", - " (blocks): Sequential(\n", - " (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (1): SELU(inplace=True)\n", - " (2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (3): Dropout(p=0.1, inplace=False)\n", - " (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): SELU(inplace=True)\n", - " (6): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", - " )\n", - " (shortcut): Sequential(\n", - " (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", - " )\n", - " )\n", - " )\n", - " (3): Sequential(\n", - " (0): WideBlock(\n", - " (activation): SELU(inplace=True)\n", - " (blocks): Sequential(\n", - " (0): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (1): SELU(inplace=True)\n", - " (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (3): Dropout(p=0.1, inplace=False)\n", - " (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): SELU(inplace=True)\n", - " (6): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", - " )\n", - " (shortcut): Sequential(\n", - " (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", - " )\n", - " )\n", - " )\n", - " )\n", - ")" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "backbone" ] }, { "cell_type": "code", - 
"execution_count": 11, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "==========================================================================================\n", - "Layer (type:depth-idx) Output Shape Param #\n", - "==========================================================================================\n", - "├─Sequential: 1-1 [-1, 64, 7, 237] --\n", - "| └─Conv2d: 2-1 [-1, 64, 12, 474] 3,136\n", - "| └─BatchNorm2d: 2-2 [-1, 64, 12, 474] 128\n", - "├─SELU: 1-2 [-1, 64, 12, 474] --\n", - "├─Sequential: 1 [] --\n", - "| └─SELU: 2-3 [-1, 64, 12, 474] --\n", - "| └─MaxPool2d: 2-4 [-1, 64, 7, 237] --\n", - "├─Sequential: 1-3 [-1, 256, 1, 30] --\n", - "| └─Sequential: 2-5 [-1, 64, 7, 237] --\n", - "| | └─WideBlock: 3-1 [-1, 64, 7, 237] 73,984\n", - "| └─Sequential: 2-6 [-1, 128, 4, 119] --\n", - "| | └─WideBlock: 3-2 [-1, 128, 4, 119] 229,760\n", - "| └─Sequential: 2-7 [-1, 256, 2, 60] --\n", - "| | └─WideBlock: 3-3 [-1, 256, 2, 60] 918,272\n", - "| └─Sequential: 2-8 [-1, 256, 1, 30] --\n", - "| | └─WideBlock: 3-4 [-1, 256, 1, 30] 1,246,208\n", - "==========================================================================================\n", - "Total params: 2,471,488\n", - "Trainable params: 2,471,488\n", - "Non-trainable params: 0\n", - "Total mult-adds (M): 27.71\n", - "==========================================================================================\n", - "Input size (MB): 0.10\n", - "Forward/backward pass size (MB): 5.55\n", - "Params size (MB): 9.43\n", - "Estimated Total Size (MB): 15.08\n", - "==========================================================================================\n" - ] - }, - { - "data": { - "text/plain": [ - "==========================================================================================\n", - "Layer (type:depth-idx) Output Shape Param #\n", - "==========================================================================================\n", - "├─Sequential: 1-1 [-1, 
64, 7, 237] --\n", - "| └─Conv2d: 2-1 [-1, 64, 12, 474] 3,136\n", - "| └─BatchNorm2d: 2-2 [-1, 64, 12, 474] 128\n", - "├─SELU: 1-2 [-1, 64, 12, 474] --\n", - "├─Sequential: 1 [] --\n", - "| └─SELU: 2-3 [-1, 64, 12, 474] --\n", - "| └─MaxPool2d: 2-4 [-1, 64, 7, 237] --\n", - "├─Sequential: 1-3 [-1, 256, 1, 30] --\n", - "| └─Sequential: 2-5 [-1, 64, 7, 237] --\n", - "| | └─WideBlock: 3-1 [-1, 64, 7, 237] 73,984\n", - "| └─Sequential: 2-6 [-1, 128, 4, 119] --\n", - "| | └─WideBlock: 3-2 [-1, 128, 4, 119] 229,760\n", - "| └─Sequential: 2-7 [-1, 256, 2, 60] --\n", - "| | └─WideBlock: 3-3 [-1, 256, 2, 60] 918,272\n", - "| └─Sequential: 2-8 [-1, 256, 1, 30] --\n", - "| | └─WideBlock: 3-4 [-1, 256, 1, 30] 1,246,208\n", - "==========================================================================================\n", - "Total params: 2,471,488\n", - "Trainable params: 2,471,488\n", - "Non-trainable params: 0\n", - "Total mult-adds (M): 27.71\n", - "==========================================================================================\n", - "Input size (MB): 0.10\n", - "Forward/backward pass size (MB): 5.55\n", - "Params size (MB): 9.43\n", - "Estimated Total Size (MB): 15.08\n", - "==========================================================================================" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "summary(wr, (1, 28, 952), device=\"cpu\", depth=3)" ] }, { "cell_type": "code", - "execution_count": 30, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -904,7 +461,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -913,7 +470,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -922,7 +479,7 @@ }, { "cell_type": "code", - "execution_count": 40, + 
"execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -931,7 +488,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -940,7 +497,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -949,67 +506,34 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([1, 119, 256, 1])" - ] - }, - "execution_count": 43, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "d.shape" ] }, { "cell_type": "code", - "execution_count": 44, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([1, 119, 256])" - ] - }, - "execution_count": 44, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "d.squeeze(3).shape" ] }, { "cell_type": "code", - "execution_count": 26, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([1, 256, 4, 119])" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "b.shape" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1018,47 +542,25 @@ }, { "cell_type": "code", - "execution_count": 70, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "96" - ] - }, - "execution_count": 70, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "32 + 64" ] }, { "cell_type": "code", - "execution_count": 106, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "336" - ] - }, - "execution_count": 106, - "metadata": {}, - "output_type": "execute_result" - } - ], + 
"outputs": [], "source": [ "3 * 112" ] }, { "cell_type": "code", - "execution_count": 39, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1067,7 +569,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1076,62 +578,29 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": null, "metadata": { "scrolled": true }, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([4, 196, 128])" - ] - }, - "execution_count": 42, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "col_embed[:W].unsqueeze(0).repeat(H, 1, 1).shape" ] }, { "cell_type": "code", - "execution_count": 44, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([4, 196, 128])" - ] - }, - "execution_count": 44, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "col_embed[:H].unsqueeze(1).repeat(1, W, 1).shape" ] }, { "cell_type": "code", - "execution_count": 60, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([1, 4, 196, 256])" - ] - }, - "execution_count": 60, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ " torch.cat(\n", " [\n", @@ -1144,27 +613,16 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "784" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "4 * 196" ] }, { "cell_type": "code", - "execution_count": 39, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1173,40 +631,18 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "8" - ] - }, - 
"execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "torch.nonzero(target == 9, as_tuple=False)[0].item()" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([ 1, 1, 12, 1, 1, 1, 1, 1, 9])" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "target[:9]" ] @@ -1220,27 +656,16 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "inf" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "np.inf" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1249,22 +674,9 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "plt.figure(figsize=(15, 5))\n", "pe = PositionalEncoding(20, 0)\n", @@ -1276,7 +688,7 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1285,7 +697,7 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1294,110 +706,25 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "27.0" - ] - }, - "execution_count": 58, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "216 / 8" ] }, { "cell_type": "code", - "execution_count": 59, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "==========================================================================================\n", - "Layer (type:depth-idx) Output Shape Param #\n", - "==========================================================================================\n", - "├─Sequential: 1-1 [-1, 80] --\n", - "| └─Conv2d: 2-1 [-1, 24, 28, 952] 216\n", - "| └─BatchNorm2d: 2-2 [-1, 24, 28, 952] 48\n", - "| └─ReLU: 2-3 [-1, 24, 28, 952] --\n", - "| └─_DenseBlock: 2-4 [-1, 96, 28, 952] --\n", - "| └─_Transition: 2-5 [-1, 48, 14, 476] --\n", - "| | └─Sequential: 3-1 [-1, 48, 14, 476] 4,800\n", - "| └─_DenseBlock: 2-6 [-1, 192, 14, 476] --\n", - "| └─_Transition: 2-7 [-1, 96, 7, 238] --\n", - "| | └─Sequential: 3-2 [-1, 96, 7, 238] 18,816\n", - "| └─_DenseBlock: 2-8 [-1, 216, 7, 238] --\n", - "| └─ReLU: 2-9 [-1, 216, 7, 238] --\n", - "| └─AdaptiveAvgPool2d: 2-10 [-1, 216, 1, 1] --\n", - "| └─Rearrange: 2-11 [-1, 216] --\n", - "| └─Linear: 2-12 [-1, 80] 17,360\n", - "==========================================================================================\n", - "Total 
params: 41,240\n", - "Trainable params: 41,240\n", - "Non-trainable params: 0\n", - "Total mult-adds (M): 252.43\n", - "==========================================================================================\n", - "Input size (MB): 0.10\n", - "Forward/backward pass size (MB): 53.69\n", - "Params size (MB): 0.16\n", - "Estimated Total Size (MB): 53.95\n", - "==========================================================================================\n" - ] - }, - { - "data": { - "text/plain": [ - "==========================================================================================\n", - "Layer (type:depth-idx) Output Shape Param #\n", - "==========================================================================================\n", - "├─Sequential: 1-1 [-1, 80] --\n", - "| └─Conv2d: 2-1 [-1, 24, 28, 952] 216\n", - "| └─BatchNorm2d: 2-2 [-1, 24, 28, 952] 48\n", - "| └─ReLU: 2-3 [-1, 24, 28, 952] --\n", - "| └─_DenseBlock: 2-4 [-1, 96, 28, 952] --\n", - "| └─_Transition: 2-5 [-1, 48, 14, 476] --\n", - "| | └─Sequential: 3-1 [-1, 48, 14, 476] 4,800\n", - "| └─_DenseBlock: 2-6 [-1, 192, 14, 476] --\n", - "| └─_Transition: 2-7 [-1, 96, 7, 238] --\n", - "| | └─Sequential: 3-2 [-1, 96, 7, 238] 18,816\n", - "| └─_DenseBlock: 2-8 [-1, 216, 7, 238] --\n", - "| └─ReLU: 2-9 [-1, 216, 7, 238] --\n", - "| └─AdaptiveAvgPool2d: 2-10 [-1, 216, 1, 1] --\n", - "| └─Rearrange: 2-11 [-1, 216] --\n", - "| └─Linear: 2-12 [-1, 80] 17,360\n", - "==========================================================================================\n", - "Total params: 41,240\n", - "Trainable params: 41,240\n", - "Non-trainable params: 0\n", - "Total mult-adds (M): 252.43\n", - "==========================================================================================\n", - "Input size (MB): 0.10\n", - "Forward/backward pass size (MB): 53.69\n", - "Params size (MB): 0.16\n", - "Estimated Total Size (MB): 53.95\n", - 
"==========================================================================================" - ] - }, - "execution_count": 59, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "summary(dnet, (1, 28, 952), device=\"cpu\", depth=3)" ] }, { "cell_type": "code", - "execution_count": 84, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1408,27 +735,16 @@ }, { "cell_type": "code", - "execution_count": 85, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Sequential()" - ] - }, - "execution_count": 85, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "backbone" ] }, { "cell_type": "code", - "execution_count": 29, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1437,7 +753,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1455,74 +771,16 @@ }, { "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "==========================================================================================\n", - "Layer (type:depth-idx) Output Shape Param #\n", - "==========================================================================================\n", - "├─Sequential: 1-1 [-1, 512, 2, 60] --\n", - "| └─Conv2d: 2-1 [-1, 32, 28, 952] 288\n", - "| └─Sequential: 2-2 [-1, 32, 28, 952] 18,560\n", - "| └─Sequential: 2-3 [-1, 64, 14, 476] 57,536\n", - "| └─Sequential: 2-4 [-1, 128, 7, 238] 229,760\n", - "| └─Sequential: 2-5 [-1, 256, 4, 119] 918,272\n", - "| └─Sequential: 2-6 [-1, 512, 2, 60] 3,671,552\n", - "==========================================================================================\n", - "Total params: 4,895,968\n", - "Trainable params: 4,895,968\n", - "Non-trainable params: 0\n", - "Total 
mult-adds (M): 22.36\n", - "==========================================================================================\n", - "Input size (MB): 0.10\n", - "Forward/backward pass size (MB): 6.51\n", - "Params size (MB): 18.68\n", - "Estimated Total Size (MB): 25.29\n", - "==========================================================================================\n" - ] - }, - { - "data": { - "text/plain": [ - "==========================================================================================\n", - "Layer (type:depth-idx) Output Shape Param #\n", - "==========================================================================================\n", - "├─Sequential: 1-1 [-1, 512, 2, 60] --\n", - "| └─Conv2d: 2-1 [-1, 32, 28, 952] 288\n", - "| └─Sequential: 2-2 [-1, 32, 28, 952] 18,560\n", - "| └─Sequential: 2-3 [-1, 64, 14, 476] 57,536\n", - "| └─Sequential: 2-4 [-1, 128, 7, 238] 229,760\n", - "| └─Sequential: 2-5 [-1, 256, 4, 119] 918,272\n", - "| └─Sequential: 2-6 [-1, 512, 2, 60] 3,671,552\n", - "==========================================================================================\n", - "Total params: 4,895,968\n", - "Trainable params: 4,895,968\n", - "Non-trainable params: 0\n", - "Total mult-adds (M): 22.36\n", - "==========================================================================================\n", - "Input size (MB): 0.10\n", - "Forward/backward pass size (MB): 6.51\n", - "Params size (MB): 18.68\n", - "Estimated Total Size (MB): 25.29\n", - "==========================================================================================" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "summary(w, (1, 28, 952), device=\"cpu\", depth=2)" ] }, { "cell_type": "code", - "execution_count": 46, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1531,7 +789,7 @@ }, { "cell_type": "code", - "execution_count": 47, + 
"execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1541,7 +799,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1551,71 +809,34 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([100, 1, 256])" - ] - }, - "execution_count": 52, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "h.flatten(2).permute(2, 0, 1).shape" ] }, { "cell_type": "code", - "execution_count": 91, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([100, 1, 256])" - ] - }, - "execution_count": 91, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "h.flatten(2).permute(2, 0, 1).shape" ] }, { "cell_type": "code", - "execution_count": 48, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[0., -inf, -inf, -inf, -inf],\n", - " [0., 0., -inf, -inf, -inf],\n", - " [0., 0., 0., -inf, -inf],\n", - " [0., 0., 0., 0., -inf],\n", - " [0., 0., 0., 0., 0.]])" - ] - }, - "execution_count": 48, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "mask\n" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1625,7 +846,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1634,71 +855,34 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([ True, True, True, True, True, True, False, False, False, False,\n", - " True, True, True, True, True, True, False, False, False, False,\n", - " True, True, True, True, True, True, False, False, False, 
False])" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "mask" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([ 1, 21, 2, 45, 31, 81, 0, 0, 0, 0, 2, 1, 1, 1, 1, 81, 0, 0,\n", - " 0, 0, 1, 1, 1, 1, 1, 81, 0, 0, 0, 0])" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "pred * mask" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([ 1, 1, 1, 1, 1, 81, 0, 0, 0, 0, 1, 1, 1, 1, 1, 81, 0, 0,\n", - " 0, 0, 1, 1, 1, 1, 1, 81, 0, 0, 0, 0])" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "target * mask" ] }, { "cell_type": "code", - "execution_count": 32, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1707,7 +891,7 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1716,7 +900,7 @@ }, { "cell_type": "code", - "execution_count": 76, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1725,27 +909,16 @@ }, { "cell_type": "code", - "execution_count": 66, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "30" - ] - }, - "execution_count": 66, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "target.shape[0]" ] }, { "cell_type": "code", - "execution_count": 84, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1754,27 +927,16 @@ }, { "cell_type": "code", - "execution_count": 85, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([10, 20, 30])" - ] - }, - 
"execution_count": 85, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "t2" ] }, { "cell_type": "code", - "execution_count": 89, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1784,160 +946,72 @@ }, { "cell_type": "code", - "execution_count": 90, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([ 1, 1, 1, 1, 1, 81, 79, 79, 79, 79, 2, 1, 1, 1, 1, 81, 79, 79,\n", - " 79, 79, 1, 1, 1, 1, 1, 81, 79, 79, 79, 79])" - ] - }, - "execution_count": 90, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "pred" ] }, { "cell_type": "code", - "execution_count": 88, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "ename": "SyntaxError", - "evalue": "invalid syntax (, line 1)", - "output_type": "error", - "traceback": [ - "\u001b[0;36m File \u001b[0;32m\"\"\u001b[0;36m, line \u001b[0;32m1\u001b[0m\n\u001b[0;31m [pred[start+1:stop] = 79 for start, stop in zip(t1, t2)]\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax\n" - ] - } - ], + "outputs": [], "source": [ "[pred[start+1:stop] = 79 for start, stop in zip(t1, t2)]" ] }, { "cell_type": "code", - "execution_count": 69, + "execution_count": null, "metadata": { "scrolled": true }, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[ 6],\n", - " [ 7],\n", - " [ 8],\n", - " [ 9],\n", - " [16],\n", - " [17],\n", - " [18],\n", - " [19],\n", - " [26],\n", - " [27],\n", - " [28],\n", - " [29]])" - ] - }, - "execution_count": 69, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "pad_indcies" ] }, { "cell_type": "code", - "execution_count": 71, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "ename": "TypeError", - "evalue": "only integer tensors of a single element can be converted to an index", - "output_type": "error", - "traceback": [ - 
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mpred\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpad_indcies\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mpad_indcies\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m79\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m: only integer tensors of a single element can be converted to an index" - ] - } - ], + "outputs": [], "source": [ "pred[pad_indcies:pad_indcies] = 79" ] }, { "cell_type": "code", - "execution_count": 50, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([20])" - ] - }, - "execution_count": 50, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "pred.shape" ] }, { "cell_type": "code", - "execution_count": 51, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([20])" - ] - }, - "execution_count": 51, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "target.shape" ] }, { "cell_type": "code", - "execution_count": 91, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "0.0" - ] - }, - "execution_count": 91, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "accuracy(pred, target)" ] }, { "cell_type": "code", - "execution_count": 92, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1946,20 +1020,9 @@ }, { "cell_type": "code", - "execution_count": 93, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor(0.9667)" - ] - }, - "execution_count": 93, - "metadata": {}, - "output_type": 
"execute_result" - } - ], + "outputs": [], "source": [ "acc" ] diff --git a/src/notebooks/07-try-gtn.ipynb b/src/notebooks/07-try-gtn.ipynb index d366dec..4ef444b 100644 --- a/src/notebooks/07-try-gtn.ipynb +++ b/src/notebooks/07-try-gtn.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -123,6 +123,53 @@ "print(g1.grad().weights_to_list()) " ] }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[1.0, 0.0, 0.5, 0.5]\n" + ] + } + ], + "source": [ + "import gtn\n", + "\n", + "# Make some graphs:\n", + "g1 = gtn.Graph()\n", + "g1.add_node(True) # Add a start node\n", + "g1.add_node() # Add an internal node\n", + "g1.add_node(False, True) # Add an accepting node\n", + "\n", + "# Add arcs with (src node, dst node, label):\n", + "g1.add_arc(0, 1, 1)\n", + "g1.add_arc(0, 1, 2)\n", + "g1.add_arc(1, 2, 1)\n", + "g1.add_arc(1, 2, 0)\n", + "\n", + "g2 = gtn.Graph()\n", + "g2.add_node(True, True)\n", + "g2.add_arc(0, 0, 1)\n", + "g2.add_arc(0, 0, 0)\n", + "\n", + "# Compute a function of the graphs:\n", + "intersection = gtn.intersect(g1, g2)\n", + "score = gtn.forward_score(intersection)\n", + "\n", + "# Visualize the intersected graph:\n", + "gtn.draw(intersection, \"intersection.pdf\")\n", + "\n", + "# Backprop:\n", + "gtn.backward(score)\n", + "\n", + "# Print gradients of arc weights \n", + "print(g1.grad().weights_to_list()) # [1.0, 0.0, 1.0, 0.0]" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/src/notebooks/Untitled.ipynb b/src/notebooks/Untitled.ipynb new file mode 100644 index 0000000..841a37d --- /dev/null +++ b/src/notebooks/Untitled.ipynb @@ -0,0 +1,385 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "%matplotlib 
inline\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from PIL import Image\n", + "import torch\n", + "from torch import nn\n", + "\n", + "from importlib.util import find_spec\n", + "if find_spec(\"text_recognizer\") is None:\n", + " import sys\n", + " sys.path.append('..')" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from text_recognizer.datasets import IamLinesDataset" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "transform = [{\"type\": \"ToPILImage\", \"args\": None}, \n", + " #{\"type\": \"RandomResizeCrop\", \"args\": None}, \n", + " {\"type\": \"RandomRotation\", \"args\": {\"degrees\": 0.8, \"fill\": 0}}, \n", + " {\"type\": \"ColorJitter\", \"args\": {\"brightness\": 0.5, \"contrast\": 0.5, \"saturation\": 0.5, \"hue\": 0.5}}, \n", + " {\"type\": \"ToTensor\", \"args\": None}, \n", + " {\"type\": \"Normalize\", \"args\": {\"mean\": [0.912], \"std\": 0.168}},\n", + " #{\"type\": \"RandomAffine\", \"args\": {\"degrees\": [-0.25, 0.25], \"scale\": [0.98, 1.0]}}\n", + " ]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "target_transforms = [\n", + " {\"type\": \"ToLower\", \"args\": None},\n", + " {\"type\": \"ToCharcters\", \"args\": {\"pad_token\": \"_\", \"eos_token\": \"\"}},\n", + " {\"type\": \"ToWordPieces\", \"args\": {\n", + " \"num_features\": 64, \n", + " \"tokens\": \"iamdb_1kwp_tokens_1000.txt\", \n", + " \"lexicon\": \"iamdb_1kwp_lex_1000.txt\",\n", + " \"use_words\": False,\n", + " \"prepend_wordsep\": False,\n", + " }\n", + " }\n", + " \n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from text_recognizer.datasets.transforms import ToText" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": 
"stderr", + "output_type": "stream", + "text": [ + "2021-02-24 21:43:47.687 | DEBUG | text_recognizer.datasets.transforms:__init__:201 - Using data dir: /home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/data/raw/iam/iamdb\n" + ] + } + ], + "source": [ + "to_text = ToText(\n", + " num_features= 64, \n", + " tokens=\"iamdb_1kwp_tokens_1000.txt\", \n", + " lexicon=\"iamdb_1kwp_lex_1000.txt\",\n", + " use_words=False,\n", + " prepend_wordsep= False,)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2021-02-24 21:42:02.700 | DEBUG | text_recognizer.datasets.transforms:__init__:201 - Using data dir: /home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/data/raw/iam/iamdb\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "IAM Lines Dataset\n", + "Number classes: 54\n", + "Mapping: {0: '0', 1: '1', 2: '2', 3: '3', 4: '4', 5: '5', 6: '6', 7: '7', 8: '8', 9: '9', 10: 'a', 11: 'b', 12: 'c', 13: 'd', 14: 'e', 15: 'f', 16: 'g', 17: 'h', 18: 'i', 19: 'j', 20: 'k', 21: 'l', 22: 'm', 23: 'n', 24: 'o', 25: 'p', 26: 'q', 27: 'r', 28: 's', 29: 't', 30: 'u', 31: 'v', 32: 'w', 33: 'x', 34: 'y', 35: 'z', 36: ' ', 37: '!', 38: '\"', 39: '#', 40: '&', 41: \"'\", 42: '(', 43: ')', 44: '*', 45: '+', 46: ',', 47: '-', 48: '.', 49: '/', 50: ':', 51: ';', 52: '?', 53: '_'}\n", + "Data: (1861, 28, 952)\n", + "Targets: (1861, 97)\n", + "\n" + ] + } + ], + "source": [ + "dataset = IamLinesDataset(train=False, pad_token=\"_\", transform=transform, target_transform=target_transforms, lower=True)\n", + "dataset.load_or_generate_data()\n", + "print(dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + 
"but▁since▁starting▁salaries▁would▁depend▁on▁grade▁a\n", + "or▁b▁in▁the▁finals▁next▁may,▁and▁since▁mating\n", + "prospects▁would▁depend▁upon▁salaries,▁scholarship▁for\n", + "these▁fine▁young▁people▁was▁closely▁geared▁to\n", + "economic▁and▁biological▁ends▁which,▁essentially,\n", + "were▁really▁means.▁so,▁seeing▁them▁revolve▁in\n", + "circles,▁harry▁had▁the▁feeling▁that▁moke▁(or▁what\n", + "moke▁consciously▁or▁unconsciously▁symbolised,▁any-\n", + "way▁in▁harry's▁mind)▁had▁these▁splendid▁young\n", + "people▁by▁the▁short▁hairs,▁and▁was▁diverting▁them▁...\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "for i in range(190, 200):\n", + " plt.figure(figsize=(20, 20))\n", + " plt.xticks([])\n", + " plt.yticks([])\n", + " data, target = dataset[i]\n", + "# print(target)\n", + " print(to_text(target))\n", + "# target = [x - 26 if x > 35 else x for x in target]\n", + "# sentence = convert_y_label_to_string(target, dataset) \n", + "# print(target)\n", + "# plt.title(sentence)\n", + " plt.imshow(data.squeeze(0).numpy(), cmap='gray')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "target.tolist()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataset.target_transform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from text_recognizer.networks.transducer import load_transducer_loss, Transducer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "t, i =load_transducer_loss(64, \n", + " 0,\n", + " \"iamdb_1kwp_tokens_1000.txt\", \n", + " \"iamdb_1kwp_lex_1000.txt\",\n", + " \"1kwp_prune_0_0_optblank.bin\",\n", + " \"optional\",\n", + " False,\n", + " False,\n", + " False,\n", + " None,\n", + " \"mean\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "t(target, target)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "target.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", 
+ "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.2" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/src/notebooks/intersection.pdf b/src/notebooks/intersection.pdf new file mode 100644 index 0000000..c425a9f Binary files /dev/null and b/src/notebooks/intersection.pdf differ diff --git a/src/tasks/build_transitions.py b/src/tasks/build_transitions.py index b12c9bc..91f8c1a 100644 --- a/src/tasks/build_transitions.py +++ b/src/tasks/build_transitions.py @@ -9,7 +9,7 @@ Most code stolen from here: import collections import itertools from pathlib import Path -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import click import gtn @@ -18,7 +18,7 @@ from loguru import logger START_IDX = -1 END_IDX = -2 -WORDSEP = "_" +WORDSEP = "▁" def build_graph(ngrams: List, disable_backoff: bool = False) -> gtn.Graph: @@ -27,7 +27,7 @@ def build_graph(ngrams: List, disable_backoff: bool = False) -> gtn.Graph: ngram = len(ngrams) state_to_node = {} - def get_node(state: Optional[List]) -> gtn.node: + def get_node(state: Optional[List]) -> Any: node = state_to_node.get(state, None) if node is not None: diff --git a/src/tasks/make_wordpieces.py b/src/tasks/make_wordpieces.py index f605920..2ac0e2c 100644 --- a/src/tasks/make_wordpieces.py +++ b/src/tasks/make_wordpieces.py @@ -30,7 +30,7 @@ def iamdb_pieces( user_symbols=["/"], # added so token is in the output set ) - vocab = sorted(set(w for t in text for w in t.split("_") if w)) + vocab = sorted(set(w for t in text for w in t.split("▁") if w)) if "move" not in vocab: raise 
RuntimeError("`MOVE` not in vocab") diff --git a/src/text_recognizer/datasets/iam_preprocessor.py b/src/text_recognizer/datasets/iam_preprocessor.py index 5a5136c..a93eb00 100644 --- a/src/text_recognizer/datasets/iam_preprocessor.py +++ b/src/text_recognizer/datasets/iam_preprocessor.py @@ -59,7 +59,7 @@ class Preprocessor: use_words: bool = False, prepend_wordsep: bool = False, ) -> None: - self.wordsep = "_" + self.wordsep = "▁" self._use_word = use_words self._prepend_wordsep = prepend_wordsep diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py index 60987e0..b6a48f5 100644 --- a/src/text_recognizer/datasets/transforms.py +++ b/src/text_recognizer/datasets/transforms.py @@ -1,6 +1,10 @@ """Transforms for PyTorch datasets.""" +from abc import abstractmethod +from pathlib import Path import random +from typing import Any, Optional, Union +from loguru import logger import numpy as np from PIL import Image import torch @@ -18,6 +22,7 @@ from torchvision.transforms import ( ToTensor, ) +from text_recognizer.datasets.iam_preprocessor import Preprocessor from text_recognizer.datasets.util import EmnistMapper @@ -145,3 +150,117 @@ class ToLower: """Corrects index value in target tensor.""" device = target.device return torch.stack([x - 26 if x > 35 else x for x in target]).to(device) + + +class ToCharcters: + """Converts integers to characters.""" + + def __init__( + self, pad_token: str, eos_token: str, init_token: str = None, lower: bool = True + ) -> None: + self.init_token = init_token + self.pad_token = pad_token + self.eos_token = eos_token + if self.init_token is not None: + self.emnist_mapper = EmnistMapper( + init_token=self.init_token, + pad_token=self.pad_token, + eos_token=self.eos_token, + lower=lower, + ) + else: + self.emnist_mapper = EmnistMapper( + pad_token=self.pad_token, eos_token=self.eos_token, lower=lower + ) + + def __call__(self, y: Tensor) -> str: + """Converts a Tensor to a str.""" + return ( + 
"".join([self.emnist_mapper(int(i)) for i in y]) + .strip("_") + .replace(" ", "▁") + ) + + +class WordPieces: + """Abstract transform for word pieces.""" + + def __init__( + self, + num_features: int, + data_dir: Optional[Union[str, Path]] = None, + tokens: Optional[Union[str, Path]] = None, + lexicon: Optional[Union[str, Path]] = None, + use_words: bool = False, + prepend_wordsep: bool = False, + ) -> None: + if data_dir is None: + data_dir = ( + Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb" + ) + logger.debug(f"Using data dir: {data_dir}") + if not data_dir.exists(): + raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") + else: + data_dir = Path(data_dir) + processed_path = ( + Path(__file__).resolve().parents[3] / "data" / "processed" / "iam_lines" + ) + tokens_path = processed_path / tokens + lexicon_path = processed_path / lexicon + + self.preprocessor = Preprocessor( + data_dir, + num_features, + tokens_path, + lexicon_path, + use_words, + prepend_wordsep, + ) + + @abstractmethod + def __call__(self, *args, **kwargs) -> Any: + """Transforms input.""" + ... 
+ + +class ToWordPieces(WordPieces): + """Transforms str to word pieces.""" + + def __init__( + self, + num_features: int, + data_dir: Optional[Union[str, Path]] = None, + tokens: Optional[Union[str, Path]] = None, + lexicon: Optional[Union[str, Path]] = None, + use_words: bool = False, + prepend_wordsep: bool = False, + ) -> None: + super().__init__( + num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep + ) + + def __call__(self, line: str) -> Tensor: + """Transforms str to word pieces.""" + return self.preprocessor.to_index(line) + + +class ToText(WordPieces): + """Takes word pieces and converts them to text.""" + + def __init__( + self, + num_features: int, + data_dir: Optional[Union[str, Path]] = None, + tokens: Optional[Union[str, Path]] = None, + lexicon: Optional[Union[str, Path]] = None, + use_words: bool = False, + prepend_wordsep: bool = False, + ) -> None: + super().__init__( + num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep + ) + + def __call__(self, x: Tensor) -> str: + """Converts tensor to text.""" + return self.preprocessor.to_text(x.tolist()) diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py index bac5d28..1521355 100644 --- a/src/text_recognizer/networks/__init__.py +++ b/src/text_recognizer/networks/__init__.py @@ -8,7 +8,7 @@ from .lenet import LeNet from .metrics import accuracy, cer, wer from .mlp import MLP from .residual_network import ResidualNetwork, ResidualNetworkEncoder -from .transducer import TDS2d +from .transducer import load_transducer_loss, TDS2d from .transformer import Transformer from .unet import UNet from .util import sliding_window @@ -28,6 +28,7 @@ __all__ = [ "greedy_decoder", "MLP", "LeNet", + "load_transducer_loss", "ResidualNetwork", "ResidualNetworkEncoder", "sliding_window", diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py index 7133c26..a2d7926 100644 --- 
a/src/text_recognizer/networks/cnn_transformer.py +++ b/src/text_recognizer/networks/cnn_transformer.py @@ -112,11 +112,11 @@ class CNNTransformer(nn.Module): if self.max_pool is not None: src = self.max_pool(src) - if self.adaptive_pool is not None: + if self.adaptive_pool is not None and len(src.shape) == 4: src = rearrange(src, "b c h w -> b w c h") src = self.adaptive_pool(src) src = src.squeeze(3) - else: + elif len(src.shape) == 4: src = rearrange(src, "b c h w -> b (h w) c") b, t, _ = src.shape diff --git a/src/text_recognizer/networks/transducer/__init__.py b/src/text_recognizer/networks/transducer/__init__.py index fdd6662..8c19a01 100644 --- a/src/text_recognizer/networks/transducer/__init__.py +++ b/src/text_recognizer/networks/transducer/__init__.py @@ -1,2 +1,3 @@ """Transducer modules.""" from .tds_conv import TDS2d +from .transducer import load_transducer_loss, Transducer diff --git a/src/text_recognizer/networks/transducer/tds_conv.py b/src/text_recognizer/networks/transducer/tds_conv.py index 018caf2..5fb8ba9 100644 --- a/src/text_recognizer/networks/transducer/tds_conv.py +++ b/src/text_recognizer/networks/transducer/tds_conv.py @@ -136,8 +136,10 @@ class TDS2d(nn.Module): self.tds = None self.fc = None - def _build_network(self) -> None: + self._build_network() + def _build_network(self) -> None: + in_channels = self.in_channels modules = [] stride_h = np.prod([grp["stride"][0] for grp in self.tds_groups]) if self.input_dim % stride_h: @@ -151,7 +153,7 @@ class TDS2d(nn.Module): modules.extend( [ nn.Conv2d( - in_channels=self.in_channels, + in_channels=in_channels, out_channels=out_channels, kernel_size=self.kernel_size, padding=(self.kernel_size[0] // 2, self.kernel_size[1] // 2), @@ -173,12 +175,10 @@ class TDS2d(nn.Module): ) ) - self.in_channels = out_channels + in_channels = out_channels self.tds = nn.Sequential(*modules) - self.fc = nn.Linear( - self.in_channels * self.input_dim // stride_h, self.output_dim - ) + self.fc = 
nn.Linear(in_channels * self.input_dim // stride_h, self.output_dim) def forward(self, x: Tensor) -> Tensor: """Forward pass. @@ -193,6 +193,9 @@ class TDS2d(nn.Module): Tensor: Output tensor. """ + if len(x.shape) == 4: + x = x.squeeze(1) # Squeeze the channel dim away. + B, H, W = x.shape x = rearrange( x, "b (h1 h2) w -> b h1 h2 w", h1=self.in_channels, h2=H // self.in_channels diff --git a/src/text_recognizer/networks/transducer/test.py b/src/text_recognizer/networks/transducer/test.py new file mode 100644 index 0000000..cadcecc --- /dev/null +++ b/src/text_recognizer/networks/transducer/test.py @@ -0,0 +1,60 @@ +import torch +from torch import nn + +from text_recognizer.networks.transducer import load_transducer_loss, Transducer +import unittest + + +class TestTransducer(unittest.TestCase): + def test_viterbi(self): + T = 5 + N = 4 + B = 2 + + # fmt: off + emissions1 = torch.tensor(( + 0, 4, 0, 1, + 0, 2, 1, 1, + 0, 0, 0, 2, + 0, 0, 0, 2, + 8, 0, 0, 2, + ), + dtype=torch.float, + ).view(T, N) + emissions2 = torch.tensor(( + 0, 2, 1, 7, + 0, 2, 9, 1, + 0, 0, 0, 2, + 0, 0, 5, 2, + 1, 0, 0, 2, + ), + dtype=torch.float, + ).view(T, N) + # fmt: on + + # Test without blank: + labels = [[1, 3, 0], [3, 2, 3, 2, 3]] + transducer = Transducer( + tokens=["a", "b", "c", "d"], + graphemes_to_idx={"a": 0, "b": 1, "c": 2, "d": 3}, + blank="none", + ) + emissions = torch.stack([emissions1, emissions2], dim=0) + predictions = transducer.viterbi(emissions) + self.assertEqual([p.tolist() for p in predictions], labels) + + # Test with blank without repeats: + labels = [[1, 0], [2, 2]] + transducer = Transducer( + tokens=["a", "b", "c"], + graphemes_to_idx={"a": 0, "b": 1, "c": 2}, + blank="optional", + allow_repeats=False, + ) + emissions = torch.stack([emissions1, emissions2], dim=0) + predictions = transducer.viterbi(emissions) + self.assertEqual([p.tolist() for p in predictions], labels) + + +if __name__ == "__main__": + unittest.main() diff --git 
"""Transducer and the transducer loss function.

Stolen from:
    https://github.com/facebookresearch/gtn_applications/blob/master/transducer.py

"""
from pathlib import Path
import itertools
from typing import Dict, List, Optional, Union, Tuple

from loguru import logger
import gtn
import torch
from torch import nn
from torch import Tensor

from text_recognizer.datasets.iam_preprocessor import Preprocessor


def make_scalar_graph(weight) -> gtn.Graph:
    """Builds a two node graph with a single arc carrying `weight`.

    Used in the backward pass as the seed gradient handed to `gtn.backward`,
    so that each sample's gradient is scaled by its reduction factor.
    """
    scalar = gtn.Graph()
    scalar.add_node(True)
    scalar.add_node(False, True)
    scalar.add_arc(0, 1, 0, 0, weight)
    return scalar


def make_chain_graph(sequence) -> gtn.Graph:
    """Builds a linear chain graph with one arc per element of `sequence`.

    Node 0 is the start node and only the node reached after the last element
    is accepting. Note that for an empty sequence no accepting node is created.
    """
    graph = gtn.Graph(False)
    graph.add_node(True)
    for i, s in enumerate(sequence):
        graph.add_node(False, i == (len(sequence) - 1))
        graph.add_arc(i, i + 1, s)
    return graph


def make_transitions_graph(
    ngram: int, num_tokens: int, calc_grad: bool = False
) -> gtn.Graph:
    """Builds a dense n-gram transition graph over `num_tokens` tokens.

    Args:
        ngram (int): Order of the transition model. Callers in this module only
            invoke this with `ngram > 0`.
        num_tokens (int): Number of distinct token labels.
        calc_grad (bool): Whether gradients are computed for the arc weights.

    Returns:
        gtn.Graph: The transition graph, with one node per token history.

    """
    transitions = gtn.Graph(calc_grad)
    transitions.add_node(True, ngram == 1)

    # Maps a token history (tuple of token indices) to its node index.
    state_map = {(): 0}

    # First build the transitions out of the start state, i.e. the states for
    # histories which are still shorter than a full n-gram context:
    for n in range(1, ngram):
        for state in itertools.product(range(num_tokens), repeat=n):
            in_idx = state_map[state[:-1]]
            out_idx = transitions.add_node(False, ngram == 1)
            state_map[state] = out_idx
            transitions.add_arc(in_idx, out_idx, state[-1])

    for state in itertools.product(range(num_tokens), repeat=ngram):
        state_idx = state_map[state[:-1]]
        new_state_idx = state_map[state[1:]]
        # p(state[-1] | state[:-1])
        transitions.add_arc(state_idx, new_state_idx, state[-1])

    if ngram > 1:
        # Build the transitions into the end state: a single accepting node
        # which is reachable from every other node through an epsilon arc.
        end_idx = transitions.add_node(False, True)
        for in_idx in range(end_idx):
            transitions.add_arc(in_idx, end_idx, gtn.epsilon)

    return transitions


def make_lexicon_graph(word_pieces: List, graphemes_to_idx: Dict) -> gtn.Graph:
    """Constructs a graph which transduces letters to word pieces.

    Each word piece becomes a path which starts and ends in node 0. Every
    grapheme but the last outputs epsilon, the last grapheme outputs the index
    of the word piece in `word_pieces`.
    """
    graph = gtn.Graph(False)
    graph.add_node(True, True)
    for i, wp in enumerate(word_pieces):
        prev = 0
        for grapheme in wp[:-1]:
            n = graph.add_node()
            graph.add_arc(prev, n, graphemes_to_idx[grapheme], gtn.epsilon)
            prev = n
        graph.add_arc(prev, 0, graphemes_to_idx[wp[-1]], i)
    graph.arc_sort()
    return graph


def make_token_graph(
    token_list: List, blank: str = "none", allow_repeats: bool = True
) -> gtn.Graph:
    """Constructs a graph with all the individual token transition models.

    Args:
        token_list (List): The output tokens, only its length is used.
        blank (str): One of "none", "optional" or "forced".
        allow_repeats (bool): If False, the same token may not directly follow
            itself in the output.

    Returns:
        gtn.Graph: Graph mapping per-frame labels to token labels.

    Raises:
        ValueError: If repeats are disallowed without an optional blank.

    """
    if not allow_repeats and blank != "optional":
        raise ValueError("Must use blank='optional' if disallowing repeats.")

    ntoks = len(token_list)
    graph = gtn.Graph(False)

    # Creating nodes
    graph.add_node(True, True)
    for i in range(ntoks):
        # We can consume one or more consecutive word
        # pieces for each emission:
        # E.g. [ab, ab, ab] transduces to [ab]
        graph.add_node(False, blank != "forced")

    if blank != "none":
        graph.add_node()

    # Creating arcs
    if blank != "none":
        # Blank index is assumed to be last (ntoks)
        graph.add_arc(0, ntoks + 1, ntoks, gtn.epsilon)
        graph.add_arc(ntoks + 1, 0, gtn.epsilon)

    for i in range(ntoks):
        graph.add_arc((ntoks + 1) if blank == "forced" else 0, i + 1, i)
        graph.add_arc(i + 1, i + 1, i, gtn.epsilon)

        if allow_repeats:
            if blank == "forced":
                # Allow transitions from token to blank only
                graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon)
            else:
                # Allow transition from token to blank and all other tokens
                graph.add_arc(i + 1, 0, gtn.epsilon)

        else:
            # Allow transitions to blank and all other tokens except the same token
            graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon)
            for j in range(ntoks):
                if i != j:
                    graph.add_arc(i + 1, j + 1, j, j)

    return graph


class TransducerLossFunction(torch.autograd.Function):
    """Autograd function which computes the transducer loss with gtn graphs."""

    @staticmethod
    def forward(
        ctx,
        inputs,
        targets,
        tokens,
        lexicon,
        transition_params=None,
        transitions=None,
        reduction="none",
    ) -> Tensor:
        """Computes the batch averaged negative alignment score.

        Args:
            ctx: Autograd context, stores the graphs needed in backward.
            inputs: Emission scores, unpacked as (B, T, C).
            targets: Sequence of B target sequences.
            tokens: Token graph, see `make_token_graph`.
            lexicon: Lexicon graph, see `make_lexicon_graph`.
            transition_params: Weights for the arcs of `transitions`.
            transitions: Optional transition graph.
            reduction: If "mean", each sample loss is divided by the length of
                its target. Any other value leaves the sample loss unscaled.

        Returns:
            Tensor: Scalar tensor with the mean loss over the batch.

        Raises:
            ValueError: If `transitions` is given without `transition_params`.

        """
        B, T, C = inputs.shape

        losses = [None] * B
        emissions_graphs = [None] * B

        if transitions is not None:
            if transition_params is None:
                raise ValueError("Specified transitions, but not transition params.")

            cpu_data = transition_params.cpu().contiguous()
            transitions.set_weights(cpu_data.data_ptr())
            transitions.calc_grad = transition_params.requires_grad
            transitions.zero_grad()

        def process(b: int) -> None:
            # Create emission graph:
            emissions = gtn.linear_graph(T, C, inputs.requires_grad)
            cpu_data = inputs[b].cpu().contiguous()
            emissions.set_weights(cpu_data.data_ptr())
            target = make_chain_graph(targets[b])
            target.arc_sort(True)

            # Create token to grapheme decomposition graph:
            tokens_target = gtn.remove(gtn.project_output(gtn.compose(target, lexicon)))
            tokens_target.arc_sort()

            # Create alignment graph:
            alignments = gtn.project_input(
                gtn.remove(gtn.compose(tokens, tokens_target))
            )
            alignments.arc_sort()

            # Add transitions scores:
            if transitions is not None:
                alignments = gtn.intersect(transitions, alignments)
                alignments.arc_sort()

            loss = gtn.forward_score(gtn.intersect(emissions, alignments))

            # Normalize if needed:
            if transitions is not None:
                norm = gtn.forward_score(gtn.intersect(emissions, transitions))
                loss = gtn.subtract(loss, norm)

            losses[b] = gtn.negate(loss)

            # Save for backward:
            if emissions.calc_grad:
                emissions_graphs[b] = emissions

        gtn.parallel_for(process, range(B))

        ctx.graphs = (losses, emissions_graphs, transitions)
        ctx.input_shape = inputs.shape

        # Optionally reduce by target length
        if reduction == "mean":
            scales = [(1 / len(t) if len(t) > 0 else 1.0) for t in targets]
        else:
            scales = [1.0] * B

        ctx.scales = scales

        loss = torch.tensor(
            [sample_loss.item() * s for sample_loss, s in zip(losses, scales)]
        )
        return torch.mean(loss.to(inputs.device))

    @staticmethod
    def backward(ctx, grad_output) -> Tuple:
        """Computes the gradients for the emissions and transition parameters.

        Returns:
            Tuple: One entry per argument of `forward` (excluding `ctx`). Only
                the entries for `inputs` and `transition_params` can be
                non-None.

        """
        losses, emissions_graphs, transitions = ctx.graphs
        scales = ctx.scales

        B, T, C = ctx.input_shape
        calc_emissions = ctx.needs_input_grad[0]
        input_grad = torch.empty((B, T, C)) if calc_emissions else None

        def process(b: int) -> None:
            # Seed the backward pass with the reduction scale of the sample.
            scale = make_scalar_graph(scales[b])
            gtn.backward(losses[b], scale)
            emissions = emissions_graphs[b]
            if calc_emissions:
                grad = emissions.grad().weights_to_numpy()
                input_grad[b] = torch.tensor(grad).view(1, T, C)

        gtn.parallel_for(process, range(B))

        if calc_emissions:
            input_grad = input_grad.to(grad_output.device)
            # Divide by B since forward returns the mean over the batch.
            input_grad *= grad_output / B

        # Index 4 is the position of `transition_params` in forward.
        if ctx.needs_input_grad[4]:
            grad = transitions.grad().weights_to_numpy()
            transition_grad = torch.tensor(grad).to(grad_output.device)
            transition_grad *= grad_output / B
        else:
            transition_grad = None

        return (
            input_grad,
            None,  # target
            None,  # tokens
            None,  # lexicon
            transition_grad,  # transition params
            None,  # transitions graph
            None,  # reduction
        )


TransducerLoss = TransducerLossFunction.apply


class Transducer(nn.Module):
    """Transducer criterion with Viterbi decoding."""

    def __init__(
        self,
        tokens: List,
        graphemes_to_idx: Dict,
        ngram: int = 0,
        transitions: Optional[gtn.Graph] = None,
        blank: str = "none",
        allow_repeats: bool = True,
        reduction: str = "none",
    ) -> None:
        """A generic transducer loss function.

        Args:
            tokens (List) : A list of iterable objects (e.g. strings, tuples, etc)
                representing the output tokens of the model (e.g. letters,
                word-pieces, words). For example ["a", "b", "ab", "ba", "aba"]
                could be a list of sub-word tokens.
            graphemes_to_idx (dict) : A dictionary mapping grapheme units (e.g.
                "a", "b", ..) to their corresponding integer index.
            ngram (int) : Order of the token-level transition model. If `ngram=0`
                then no transition model is used.
            transitions (gtn.Graph) : An already constructed transition graph.
                May not be combined with `ngram > 0`.
            blank (string) : Specifies the usage of blank token
                'none' - do not use blank token
                'optional' - allow an optional blank inbetween tokens
                'forced' - force a blank inbetween tokens (also referred to as
                    garbage token)
            allow_repeats (boolean) : If false, then we don't allow paths with
                consecutive tokens in the alignment graph. This keeps the graph
                unambiguous in the sense that the same input cannot transduce to
                different outputs.
            reduction (string) : If 'mean', the loss of each sample is divided
                by the length of its target, otherwise it is left unscaled.

        Raises:
            ValueError: If `blank` has an invalid value, or if both `ngram > 0`
                and `transitions` are specified.

        """
        super().__init__()
        if blank not in ["optional", "forced", "none"]:
            raise ValueError(
                "Invalid value specified for blank. Must be in ['optional', 'forced', 'none']"
            )
        self.tokens = make_token_graph(tokens, blank=blank, allow_repeats=allow_repeats)
        self.lexicon = make_lexicon_graph(tokens, graphemes_to_idx)
        self.ngram = ngram
        if ngram > 0 and transitions is not None:
            raise ValueError("Only one of ngram and transitions may be specified")

        if ngram > 0:
            transitions = make_transitions_graph(
                ngram, len(tokens) + int(blank != "none"), True
            )

        if transitions is not None:
            self.transitions = transitions
            self.transitions.arc_sort()
            self.transitions_params = nn.Parameter(
                torch.zeros(self.transitions.num_arcs())
            )
        else:
            self.transitions = None
            self.transitions_params = None
        self.reduction = reduction

    def forward(self, inputs: Tensor, targets: Tensor) -> Tensor:
        """Computes the transducer loss for a batch.

        Args:
            inputs (Tensor): Emission scores, unpacked as (B, T, C) by the loss.
            targets (Tensor): The B target sequences.

        Returns:
            Tensor: Scalar loss tensor.

        """
        # The loss has to be returned, otherwise the module evaluates to None.
        return TransducerLoss(
            inputs,
            targets,
            self.tokens,
            self.lexicon,
            self.transitions_params,
            self.transitions,
            self.reduction,
        )

    def viterbi(self, outputs: Tensor) -> List[Tensor]:
        """Decodes the best token sequence for every sample in the batch.

        Args:
            outputs (Tensor): Emission scores, unpacked as (B, T, C).

        Returns:
            List[Tensor]: One integer tensor of token indices per sample.

        """
        B, T, C = outputs.shape

        if self.transitions is not None:
            # The attribute is named `transitions_params` in __init__.
            cpu_data = self.transitions_params.cpu().contiguous()
            self.transitions.set_weights(cpu_data.data_ptr())
            self.transitions.calc_grad = False

        self.tokens.arc_sort()

        paths = [None] * B

        def process(b: int) -> None:
            emissions = gtn.linear_graph(T, C, False)
            cpu_data = outputs[b].cpu().contiguous()
            emissions.set_weights(cpu_data.data_ptr())

            if self.transitions is not None:
                full_graph = gtn.intersect(emissions, self.transitions)
            else:
                full_graph = emissions

            # Find the best path and remove back-off arcs:
            path = gtn.remove(gtn.viterbi_path(full_graph))

            # Left compose the viterbi path with the "alignment to token"
            # transducer to get the outputs:
            path = gtn.compose(path, self.tokens)

            # When there are ambiguous paths (allow_repeats is true), we take
            # the shortest:
            path = gtn.viterbi_path(path)
            path = gtn.remove(gtn.project_output(path))
            paths[b] = path.labels_to_list()

        gtn.parallel_for(process, range(B))
        predictions = [torch.IntTensor(path) for path in paths]
        return predictions


def load_transducer_loss(
    num_features: int,
    ngram: int,
    tokens: str,
    lexicon: str,
    transitions: Optional[str],
    blank: str,
    allow_repeats: bool,
    prepend_wordsep: bool = False,
    use_words: bool = False,
    data_dir: Optional[Union[str, Path]] = None,
    reduction: str = "mean",
) -> Tuple[Transducer, int]:
    """Builds the transducer criterion from the processed IAM lines files.

    Args:
        num_features (int): Forwarded to the `Preprocessor`.
        ngram (int): Order of the transition model, see `Transducer`.
        tokens (str): Filename of the tokens file in the processed directory.
        lexicon (str): Filename of the lexicon file in the processed directory.
        transitions (Optional[str]): Filename of a saved transition graph in
            the processed directory, or None for no pre-built transitions.
        blank (str): Blank mode, see `Transducer`.
        allow_repeats (bool): See `Transducer`.
        prepend_wordsep (bool): Forwarded to the `Preprocessor`.
        use_words (bool): Forwarded to the `Preprocessor`.
        data_dir (Optional[Union[str, Path]]): Path to the raw iamdb directory.
            Defaults to `data/raw/iam/iamdb` relative to the repository root.
        reduction (str): Loss reduction, see `Transducer`.

    Returns:
        Tuple[Transducer, int]: The criterion and the number of output classes,
            i.e. the number of tokens plus one if a blank token is used.

    Raises:
        RuntimeError: If the default iamdb directory does not exist.

    """
    if data_dir is None:
        data_dir = (
            Path(__file__).resolve().parents[4] / "data" / "raw" / "iam" / "iamdb"
        )
        logger.debug(f"Using data dir: {data_dir}")
        if not data_dir.exists():
            raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
    else:
        data_dir = Path(data_dir)
    processed_path = (
        Path(__file__).resolve().parents[4] / "data" / "processed" / "iam_lines"
    )
    tokens_path = processed_path / tokens
    lexicon_path = processed_path / lexicon

    if transitions is not None:
        transitions = gtn.load(str(processed_path / transitions))

    preprocessor = Preprocessor(
        data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep,
    )

    num_tokens = preprocessor.num_tokens

    criterion = Transducer(
        preprocessor.tokens,
        preprocessor.graphemes_to_index,
        ngram=ngram,
        transitions=transitions,
        blank=blank,
        allow_repeats=allow_repeats,
        reduction=reduction,
    )

    return criterion, num_tokens + int(blank != "none")