summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.flake81
-rw-r--r--.pre-commit-config.yaml4
-rw-r--r--.pytype/imports/default.pyi2
-rw-r--r--.vim/coc-settings.json6
-rw-r--r--noxfile.py13
-rw-r--r--poetry.lock1119
-rw-r--r--src/notebooks/00-testing-stuff-out.ipynb1361
-rw-r--r--src/notebooks/07-try-gtn.ipynb49
-rw-r--r--src/notebooks/Untitled.ipynb385
-rw-r--r--src/notebooks/intersection.pdfbin0 -> 10154 bytes
-rw-r--r--src/tasks/build_transitions.py6
-rw-r--r--src/tasks/make_wordpieces.py2
-rw-r--r--src/text_recognizer/datasets/iam_preprocessor.py2
-rw-r--r--src/text_recognizer/datasets/transforms.py119
-rw-r--r--src/text_recognizer/networks/__init__.py3
-rw-r--r--src/text_recognizer/networks/cnn_transformer.py4
-rw-r--r--src/text_recognizer/networks/transducer/__init__.py1
-rw-r--r--src/text_recognizer/networks/transducer/tds_conv.py15
-rw-r--r--src/text_recognizer/networks/transducer/test.py60
-rw-r--r--src/text_recognizer/networks/transducer/transducer.py410
20 files changed, 1819 insertions, 1743 deletions
diff --git a/.flake8 b/.flake8
index eff48a6..b00f63b 100644
--- a/.flake8
+++ b/.flake8
@@ -7,3 +7,4 @@ application-import-names = text_recognizer,tests
import-order-style = google
docstring-convention = google
per-file-ignores = tests/*:S101,tests/*:S106,src/text_recognizer/datasets/*:S110,src/training/callbacks/*:B006,src/tasks/build_transitions.py:C901
+exclude = src/text_recognizer/networks/transducer/*
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index ea3565b..8fed8e5 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -9,11 +9,11 @@ repos:
hooks:
- id: black
name: black
- entry: poetry run black
+ entry: black
language: system
types: [python]
- id: flake8
name: flake8
- entry: poetry run flake8
+ entry: flake8
language: system
types: [python]
diff --git a/.pytype/imports/default.pyi b/.pytype/imports/default.pyi
index 7bb11b4..024e962 100644
--- a/.pytype/imports/default.pyi
+++ b/.pytype/imports/default.pyi
@@ -1,3 +1,3 @@
-
from typing import Any
+
def __getattr__(name) -> Any: ...
diff --git a/.vim/coc-settings.json b/.vim/coc-settings.json
new file mode 100644
index 0000000..ce08b20
--- /dev/null
+++ b/.vim/coc-settings.json
@@ -0,0 +1,6 @@
+{
+ "python.linting.pylintEnabled": false,
+ "python.linting.flake8Enabled": true,
+ "python.linting.enabled": true,
+ "python.formatting.provider": "black"
+}
diff --git a/noxfile.py b/noxfile.py
index d1a8d1b..60c3923 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -33,6 +33,7 @@ def install_with_constraints(session: Session, *args: str, **kwargs: Any) -> Non
"export",
"--dev",
"--format=requirements.txt",
+ "--without-hashes",
f"--output={requirements.name}",
external=True,
)
@@ -47,7 +48,7 @@ def black(session: Session) -> None:
session.run("black", *args)
-@nox.session(python=["3.8", "3.7"])
+@nox.session(python=["3.8"])
def lint(session: Session) -> None:
"""Lint using flake8."""
args = session.posargs or locations
@@ -82,7 +83,7 @@ def safety(session: Session) -> None:
session.run("safety", "check", f"--file={requirements.name}", "--full-report")
-@nox.session(python=["3.8", "3.7"])
+@nox.session(python=["3.8"])
def mypy(session: Session) -> None:
"""Type-check using mypy."""
args = session.posargs or locations
@@ -90,7 +91,7 @@ def mypy(session: Session) -> None:
session.run("mypy", *args)
-@nox.session(python="3.7")
+@nox.session(python="3.8")
def pytype(session: Session) -> None:
"""Type-check using pytype."""
args = session.posargs or ["--disable=import-error", *locations]
@@ -98,7 +99,7 @@ def pytype(session: Session) -> None:
session.run("pytype", *args)
-@nox.session(python=["3.8", "3.7"])
+@nox.session(python=["3.8"])
def tests(session: Session) -> None:
"""Run the test suite."""
args = session.posargs or ["--cov", "-m", "not e2e"]
@@ -109,7 +110,7 @@ def tests(session: Session) -> None:
session.run("pytest", *args)
-@nox.session(python=["3.8", "3.7"])
+@nox.session(python=["3.8"])
def typeguard(session: Session) -> None:
"""Runtime type checking using Typeguard."""
args = session.posargs or ["-m", "not e2e"]
@@ -118,7 +119,7 @@ def typeguard(session: Session) -> None:
session.run("pytest", f"--typeguard-packages={package}", *args)
-@nox.session(python=["3.8", "3.7"])
+@nox.session(python=["3.8"])
def xdoctest(session: Session) -> None:
"""Run examples with xdoctest."""
args = session.posargs or ["all"]
diff --git a/poetry.lock b/poetry.lock
index 7f715d8..72da168 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,117 +1,115 @@
[[package]]
-category = "main"
-description = "A configurable sidebar-enabled Sphinx theme"
name = "alabaster"
+version = "0.7.12"
+description = "A configurable sidebar-enabled Sphinx theme"
+category = "main"
optional = false
python-versions = "*"
-version = "0.7.12"
[[package]]
-category = "dev"
-description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
name = "appdirs"
+version = "1.4.4"
+description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
+category = "dev"
optional = false
python-versions = "*"
-version = "1.4.4"
[[package]]
-category = "main"
-description = "Disable App Nap on OS X 10.9"
-marker = "sys_platform == \"darwin\" or platform_system == \"Darwin\" or python_version >= \"3.3\" and sys_platform == \"darwin\""
name = "appnope"
+version = "0.1.0"
+description = "Disable App Nap on OS X 10.9"
+category = "main"
optional = false
python-versions = "*"
-version = "0.1.0"
[[package]]
-category = "main"
-description = "The secure Argon2 password hashing algorithm."
name = "argon2-cffi"
+version = "20.1.0"
+description = "The secure Argon2 password hashing algorithm."
+category = "main"
optional = false
python-versions = "*"
-version = "20.1.0"
[package.dependencies]
cffi = ">=1.0.0"
six = "*"
[package.extras]
-dev = ["coverage (>=5.0.2)", "hypothesis", "pytest", "sphinx", "wheel", "pre-commit"]
+dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pytest", "sphinx", "wheel", "pre-commit"]
docs = ["sphinx"]
-tests = ["coverage (>=5.0.2)", "hypothesis", "pytest"]
+tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pytest"]
[[package]]
-category = "main"
-description = "Async generators and context managers for Python 3.5+"
name = "async-generator"
+version = "1.10"
+description = "Async generators and context managers for Python 3.5+"
+category = "main"
optional = false
python-versions = ">=3.5"
-version = "1.10"
[[package]]
-category = "main"
-description = "Atomic file writes."
-marker = "sys_platform == \"win32\""
name = "atomicwrites"
+version = "1.4.0"
+description = "Atomic file writes."
+category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
-version = "1.4.0"
[[package]]
-category = "main"
-description = "Classes Without Boilerplate"
name = "attrs"
+version = "20.3.0"
+description = "Classes Without Boilerplate"
+category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
-version = "20.3.0"
[package.extras]
-dev = ["coverage (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface", "furo", "sphinx", "pre-commit"]
+dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface", "furo", "sphinx", "pre-commit"]
docs = ["furo", "sphinx", "zope.interface"]
-tests = ["coverage (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface"]
-tests_no_zope = ["coverage (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six"]
+tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface"]
+tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six"]
[[package]]
-category = "main"
-description = "Internationalization utilities"
name = "babel"
+version = "2.9.0"
+description = "Internationalization utilities"
+category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
-version = "2.9.0"
[package.dependencies]
pytz = ">=2015.7"
[[package]]
-category = "main"
-description = "Specifications for callback functions passed in to an API"
name = "backcall"
+version = "0.2.0"
+description = "Specifications for callback functions passed in to an API"
+category = "main"
optional = false
python-versions = "*"
-version = "0.2.0"
[[package]]
-category = "dev"
-description = "Security oriented static analyser for python code."
name = "bandit"
+version = "1.6.2"
+description = "Security oriented static analyser for python code."
+category = "dev"
optional = false
python-versions = "*"
-version = "1.6.2"
[package.dependencies]
+colorama = {version = ">=0.3.9", markers = "platform_system == \"Windows\""}
GitPython = ">=1.0.1"
PyYAML = ">=3.13"
-colorama = ">=0.3.9"
six = ">=1.10.0"
stevedore = ">=1.20.0"
[[package]]
-category = "dev"
-description = "The uncompromising code formatter."
name = "black"
+version = "19.10b0"
+description = "The uncompromising code formatter."
+category = "dev"
optional = false
python-versions = ">=3.6"
-version = "19.10b0"
[package.dependencies]
appdirs = "*"
@@ -126,12 +124,12 @@ typed-ast = ">=1.4.0"
d = ["aiohttp (>=3.3.2)", "aiohttp-cors"]
[[package]]
-category = "main"
-description = "An easy safelist-based HTML-sanitizing tool."
name = "bleach"
+version = "3.2.1"
+description = "An easy safelist-based HTML-sanitizing tool."
+category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
-version = "3.2.1"
[package.dependencies]
packaging = "*"
@@ -139,146 +137,143 @@ six = ">=1.9.0"
webencodings = "*"
[[package]]
-category = "dev"
-description = "A thin, practical wrapper around terminal coloring, styling, and positioning"
name = "blessings"
+version = "1.7"
+description = "A thin, practical wrapper around terminal coloring, styling, and positioning"
+category = "dev"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
-version = "1.7"
[package.dependencies]
six = "*"
[[package]]
-category = "main"
-description = "When they're not builtins, they're boltons."
name = "boltons"
+version = "20.2.1"
+description = "When they're not builtins, they're boltons."
+category = "main"
optional = false
python-versions = "*"
-version = "20.2.1"
[[package]]
-category = "main"
-description = "Python package for providing Mozilla's CA Bundle."
name = "certifi"
+version = "2020.11.8"
+description = "Python package for providing Mozilla's CA Bundle."
+category = "main"
optional = false
python-versions = "*"
-version = "2020.11.8"
[[package]]
-category = "main"
-description = "Foreign Function Interface for Python calling C code."
name = "cffi"
+version = "1.14.3"
+description = "Foreign Function Interface for Python calling C code."
+category = "main"
optional = false
python-versions = "*"
-version = "1.14.3"
[package.dependencies]
pycparser = "*"
[[package]]
-category = "main"
-description = "Universal encoding detector for Python 2 and 3"
name = "chardet"
+version = "3.0.4"
+description = "Universal encoding detector for Python 2 and 3"
+category = "main"
optional = false
python-versions = "*"
-version = "3.0.4"
[[package]]
-category = "main"
-description = "Composable command line interface toolkit"
name = "click"
+version = "7.1.2"
+description = "Composable command line interface toolkit"
+category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
-version = "7.1.2"
[[package]]
-category = "main"
-description = "Cross-platform colored terminal text."
-marker = "sys_platform == \"win32\" or platform_system == \"Windows\" or python_version >= \"3.3\" and sys_platform == \"win32\""
name = "colorama"
+version = "0.4.4"
+description = "Cross-platform colored terminal text."
+category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
-version = "0.4.4"
[[package]]
-category = "main"
-description = "Updated configparser from Python 3.8 for Python 2.6+."
name = "configparser"
+version = "5.0.1"
+description = "Updated configparser from Python 3.8 for Python 2.6+."
+category = "main"
optional = false
python-versions = ">=3.6"
-version = "5.0.1"
[package.extras]
docs = ["sphinx", "jaraco.packaging (>=3.2)", "rst.linker (>=1.9)"]
-testing = ["pytest (>=3.5,<3.7.3 || >3.7.3)", "pytest-checkdocs (>=1.2.3)", "pytest-flake8", "pytest-cov", "jaraco.test (>=3.2.0)", "pytest-black (>=0.3.7)", "pytest-mypy"]
+testing = ["pytest (>=3.5,!=3.7.3)", "pytest-checkdocs (>=1.2.3)", "pytest-flake8", "pytest-cov", "jaraco.test (>=3.2.0)", "pytest-black (>=0.3.7)", "pytest-mypy"]
[[package]]
-category = "dev"
-description = "Code coverage measurement for Python"
name = "coverage"
+version = "5.3"
+description = "Code coverage measurement for Python"
+category = "dev"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4"
-version = "5.3"
[package.dependencies]
-[package.dependencies.toml]
-optional = true
-version = "*"
+toml = {version = "*", optional = true, markers = "extra == \"toml\""}
[package.extras]
toml = ["toml"]
[[package]]
-category = "main"
-description = "Composable style cycles"
name = "cycler"
+version = "0.10.0"
+description = "Composable style cycles"
+category = "main"
optional = false
python-versions = "*"
-version = "0.10.0"
[package.dependencies]
six = "*"
[[package]]
-category = "main"
-description = "A utility for ensuring Google-style docstrings stay up to date with the source code."
name = "darglint"
+version = "1.5.6"
+description = "A utility for ensuring Google-style docstrings stay up to date with the source code."
+category = "main"
optional = false
python-versions = ">=3.6,<4.0"
-version = "1.5.6"
[[package]]
-category = "main"
-description = "A backport of the dataclasses module for Python 3.6"
name = "dataclasses"
+version = "0.6"
+description = "A backport of the dataclasses module for Python 3.6"
+category = "main"
optional = false
python-versions = "*"
-version = "0.6"
[[package]]
-category = "main"
-description = "Decorators for Humans"
name = "decorator"
+version = "4.4.2"
+description = "Decorators for Humans"
+category = "main"
optional = false
python-versions = ">=2.6, !=3.0.*, !=3.1.*"
-version = "4.4.2"
[[package]]
-category = "main"
-description = "XML bomb protection for Python stdlib modules"
name = "defusedxml"
+version = "0.6.0"
+description = "XML bomb protection for Python stdlib modules"
+category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
-version = "0.6.0"
[[package]]
-category = "main"
-description = "Deserialize to objects while staying DRY"
name = "desert"
+version = "2020.11.18"
+description = "Deserialize to objects while staying DRY"
+category = "main"
optional = false
python-versions = ">=3.6"
-version = "2020.11.18"
[package.dependencies]
attrs = "*"
@@ -290,31 +285,31 @@ dev = ["coverage", "cuvner", "marshmallow-enum", "marshmallow-union", "pytest",
test = ["coverage", "cuvner", "marshmallow-enum", "marshmallow-union", "pytest", "pytest-cov", "pytest-sphinx", "pytest-travis-fold", "tox", "importlib-metadata"]
[[package]]
-category = "main"
-description = "Python bindings for the docker credentials store API"
name = "docker-pycreds"
+version = "0.4.0"
+description = "Python bindings for the docker credentials store API"
+category = "main"
optional = false
python-versions = "*"
-version = "0.4.0"
[package.dependencies]
six = ">=1.4.0"
[[package]]
-category = "main"
-description = "Docutils -- Python Documentation Utilities"
name = "docutils"
+version = "0.16"
+description = "Docutils -- Python Documentation Utilities"
+category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
-version = "0.16"
[[package]]
-category = "dev"
-description = "A parser for Python dependency files"
name = "dparse"
+version = "0.5.1"
+description = "A parser for Python dependency files"
+category = "dev"
optional = false
python-versions = ">=3.5"
-version = "0.5.1"
[package.dependencies]
packaging = "*"
@@ -325,28 +320,28 @@ toml = "*"
pipenv = ["pipenv"]
[[package]]
-category = "main"
-description = "A new flavour of deep learning operations"
name = "einops"
+version = "0.3.0"
+description = "A new flavour of deep learning operations"
+category = "main"
optional = false
python-versions = "*"
-version = "0.3.0"
[[package]]
-category = "main"
-description = "Discover and load entry points from installed packages."
name = "entrypoints"
+version = "0.3"
+description = "Discover and load entry points from installed packages."
+category = "main"
optional = false
python-versions = ">=2.7"
-version = "0.3"
[[package]]
-category = "main"
-description = "the modular source code checker: pep8 pyflakes and co"
name = "flake8"
+version = "3.8.4"
+description = "the modular source code checker: pep8 pyflakes and co"
+category = "main"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7"
-version = "3.8.4"
[package.dependencies]
mccabe = ">=0.6.0,<0.7.0"
@@ -354,23 +349,23 @@ pycodestyle = ">=2.6.0a1,<2.7.0"
pyflakes = ">=2.2.0,<2.3.0"
[[package]]
-category = "main"
-description = "Flake8 Type Annotation Checks"
name = "flake8-annotations"
+version = "2.4.1"
+description = "Flake8 Type Annotation Checks"
+category = "main"
optional = false
python-versions = ">=3.6.1,<4.0.0"
-version = "2.4.1"
[package.dependencies]
flake8 = ">=3.7,<3.9"
[[package]]
-category = "dev"
-description = "Automated security testing with bandit and flake8."
name = "flake8-bandit"
+version = "2.1.2"
+description = "Automated security testing with bandit and flake8."
+category = "dev"
optional = false
python-versions = "*"
-version = "2.1.2"
[package.dependencies]
bandit = "*"
@@ -379,101 +374,100 @@ flake8-polyfill = "*"
pycodestyle = "*"
[[package]]
-category = "dev"
-description = "flake8 plugin to call black as a code style validator"
name = "flake8-black"
+version = "0.2.1"
+description = "flake8 plugin to call black as a code style validator"
+category = "dev"
optional = false
python-versions = "*"
-version = "0.2.1"
[package.dependencies]
black = "*"
flake8 = ">=3.0.0"
[[package]]
-category = "dev"
-description = "A plugin for flake8 finding likely bugs and design problems in your program. Contains warnings that don't belong in pyflakes and pycodestyle."
name = "flake8-bugbear"
+version = "20.1.4"
+description = "A plugin for flake8 finding likely bugs and design problems in your program. Contains warnings that don't belong in pyflakes and pycodestyle."
+category = "dev"
optional = false
python-versions = ">=3.6"
-version = "20.1.4"
[package.dependencies]
attrs = ">=19.2.0"
flake8 = ">=3.0.0"
[[package]]
-category = "main"
-description = "Extension for flake8 which uses pydocstyle to check docstrings"
name = "flake8-docstrings"
+version = "1.5.0"
+description = "Extension for flake8 which uses pydocstyle to check docstrings"
+category = "main"
optional = false
python-versions = "*"
-version = "1.5.0"
[package.dependencies]
flake8 = ">=3"
pydocstyle = ">=2.1"
[[package]]
-category = "dev"
-description = "Flake8 and pylama plugin that checks the ordering of import statements."
name = "flake8-import-order"
+version = "0.18.1"
+description = "Flake8 and pylama plugin that checks the ordering of import statements."
+category = "dev"
optional = false
python-versions = "*"
-version = "0.18.1"
[package.dependencies]
pycodestyle = "*"
-setuptools = "*"
[[package]]
-category = "dev"
-description = "Polyfill package for Flake8 plugins"
name = "flake8-polyfill"
+version = "1.0.2"
+description = "Polyfill package for Flake8 plugins"
+category = "dev"
optional = false
python-versions = "*"
-version = "1.0.2"
[package.dependencies]
flake8 = "*"
[[package]]
-category = "main"
-description = "Clean single-source support for Python 3 and 2"
name = "future"
+version = "0.18.2"
+description = "Clean single-source support for Python 3 and 2"
+category = "main"
optional = false
python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
-version = "0.18.2"
[[package]]
-category = "main"
-description = "Git Object Database"
name = "gitdb"
+version = "4.0.5"
+description = "Git Object Database"
+category = "main"
optional = false
python-versions = ">=3.4"
-version = "4.0.5"
[package.dependencies]
smmap = ">=3.0.1,<4"
[[package]]
-category = "main"
-description = "Python Git Library"
name = "gitpython"
+version = "3.1.11"
+description = "Python Git Library"
+category = "main"
optional = false
python-versions = ">=3.4"
-version = "3.1.11"
[package.dependencies]
gitdb = ">=4.0.1,<5"
[[package]]
-category = "dev"
-description = "An utility to monitor NVIDIA GPU status and usage"
name = "gpustat"
+version = "0.6.0"
+description = "An utility to monitor NVIDIA GPU status and usage"
+category = "dev"
optional = false
python-versions = "*"
-version = "0.6.0"
[package.dependencies]
blessings = ">=1.6"
@@ -485,12 +479,12 @@ six = ">=1.7"
test = ["mock (>=2.0.0)", "pytest (<5.0)"]
[[package]]
-category = "dev"
-description = "Simple Python interface for Graphviz"
name = "graphviz"
+version = "0.16"
+description = "Simple Python interface for Graphviz"
+category = "dev"
optional = false
python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*"
-version = "0.16"
[package.extras]
dev = ["tox (>=3)", "flake8", "pep8-naming", "wheel", "twine"]
@@ -498,51 +492,51 @@ docs = ["sphinx (>=1.8)", "sphinx-rtd-theme"]
test = ["mock (>=3)", "pytest (>=4)", "pytest-mock (>=2)", "pytest-cov"]
[[package]]
-category = "main"
-description = "Automatic differentiation with WFSTs"
name = "gtn"
+version = "0.0.0"
+description = "Automatic differentiation with WFSTs"
+category = "main"
optional = false
python-versions = ">=3.5"
-version = "0.0.0"
[[package]]
-category = "main"
-description = "Read and write HDF5 files from Python"
name = "h5py"
+version = "2.10.0"
+description = "Read and write HDF5 files from Python"
+category = "main"
optional = false
python-versions = "*"
-version = "2.10.0"
[package.dependencies]
numpy = ">=1.7"
six = "*"
[[package]]
-category = "main"
-description = "Internationalized Domain Names in Applications (IDNA)"
name = "idna"
+version = "2.10"
+description = "Internationalized Domain Names in Applications (IDNA)"
+category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
-version = "2.10"
[[package]]
-category = "main"
-description = "Getting image size from png/jpeg/jpeg2000/gif file"
name = "imagesize"
+version = "1.2.0"
+description = "Getting image size from png/jpeg/jpeg2000/gif file"
+category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
-version = "1.2.0"
[[package]]
-category = "main"
-description = "IPython Kernel for Jupyter"
name = "ipykernel"
+version = "5.3.4"
+description = "IPython Kernel for Jupyter"
+category = "main"
optional = false
python-versions = ">=3.5"
-version = "5.3.4"
[package.dependencies]
-appnope = "*"
+appnope = {version = "*", markers = "platform_system == \"Darwin\""}
ipython = ">=5.0.0"
jupyter-client = "*"
tornado = ">=4.2"
@@ -552,24 +546,23 @@ traitlets = ">=4.1.0"
test = ["pytest (!=5.3.4)", "pytest-cov", "flaky", "nose"]
[[package]]
-category = "main"
-description = "IPython: Productive Interactive Computing"
name = "ipython"
+version = "7.19.0"
+description = "IPython: Productive Interactive Computing"
+category = "main"
optional = false
python-versions = ">=3.7"
-version = "7.19.0"
[package.dependencies]
-appnope = "*"
+appnope = {version = "*", markers = "sys_platform == \"darwin\""}
backcall = "*"
-colorama = "*"
+colorama = {version = "*", markers = "sys_platform == \"win32\""}
decorator = "*"
jedi = ">=0.10"
-pexpect = ">4.3"
+pexpect = {version = ">4.3", markers = "sys_platform != \"win32\""}
pickleshare = "*"
prompt-toolkit = ">=2.0.0,<3.0.0 || >3.0.0,<3.0.1 || >3.0.1,<3.1.0"
pygments = "*"
-setuptools = ">=18.5"
traitlets = ">=4.2"
[package.extras]
@@ -584,56 +577,53 @@ qtconsole = ["qtconsole"]
test = ["nose (>=0.10.1)", "requests", "testpath", "pygments", "nbformat", "ipykernel", "numpy (>=1.14)"]
[[package]]
-category = "main"
-description = "Vestigial utilities from IPython"
name = "ipython-genutils"
+version = "0.2.0"
+description = "Vestigial utilities from IPython"
+category = "main"
optional = false
python-versions = "*"
-version = "0.2.0"
[[package]]
-category = "dev"
-description = "IPython HTML widgets for Jupyter"
name = "ipywidgets"
+version = "7.5.1"
+description = "IPython HTML widgets for Jupyter"
+category = "dev"
optional = false
python-versions = "*"
-version = "7.5.1"
[package.dependencies]
ipykernel = ">=4.5.1"
+ipython = {version = ">=4.0.0", markers = "python_version >= \"3.3\""}
nbformat = ">=4.2.0"
traitlets = ">=4.3.1"
widgetsnbextension = ">=3.5.0,<3.6.0"
-[package.dependencies.ipython]
-python = ">=3.3"
-version = ">=4.0.0"
-
[package.extras]
test = ["pytest (>=3.6.0)", "pytest-cov", "mock"]
[[package]]
-category = "main"
-description = "An autocompletion tool for Python that can be used for text editors."
name = "jedi"
+version = "0.17.2"
+description = "An autocompletion tool for Python that can be used for text editors."
+category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
-version = "0.17.2"
[package.dependencies]
parso = ">=0.7.0,<0.8.0"
[package.extras]
-qa = ["flake8 (3.7.9)"]
+qa = ["flake8 (==3.7.9)"]
testing = ["Django (<3.1)", "colorama", "docopt", "pytest (>=3.9.0,<5.0.0)"]
[[package]]
-category = "main"
-description = "A very fast and expressive template engine."
name = "jinja2"
+version = "2.11.2"
+description = "A very fast and expressive template engine."
+category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
-version = "2.11.2"
[package.dependencies]
MarkupSafe = ">=0.23"
@@ -642,25 +632,24 @@ MarkupSafe = ">=0.23"
i18n = ["Babel (>=0.8)"]
[[package]]
-category = "main"
-description = "Lightweight pipelining: using Python functions as pipeline jobs."
name = "joblib"
+version = "0.17.0"
+description = "Lightweight pipelining: using Python functions as pipeline jobs."
+category = "main"
optional = false
python-versions = ">=3.6"
-version = "0.17.0"
[[package]]
-category = "main"
-description = "An implementation of JSON Schema validation for Python"
name = "jsonschema"
+version = "3.2.0"
+description = "An implementation of JSON Schema validation for Python"
+category = "main"
optional = false
python-versions = "*"
-version = "3.2.0"
[package.dependencies]
attrs = ">=17.4.0"
pyrsistent = ">=0.14.0"
-setuptools = "*"
six = ">=1.11.0"
[package.extras]
@@ -668,12 +657,12 @@ format = ["idna", "jsonpointer (>1.13)", "rfc3987", "strict-rfc3339", "webcolors
format_nongpl = ["idna", "jsonpointer (>1.13)", "webcolors", "rfc3986-validator (>0.1.0)", "rfc3339-validator"]
[[package]]
-category = "dev"
-description = "Jupyter metapackage. Install all the Jupyter components in one go."
name = "jupyter"
+version = "1.0.0"
+description = "Jupyter metapackage. Install all the Jupyter components in one go."
+category = "dev"
optional = false
python-versions = "*"
-version = "1.0.0"
[package.dependencies]
ipykernel = "*"
@@ -684,12 +673,12 @@ notebook = "*"
qtconsole = "*"
[[package]]
-category = "main"
-description = "Jupyter protocol implementation and client libraries"
name = "jupyter-client"
+version = "6.1.7"
+description = "Jupyter protocol implementation and client libraries"
+category = "main"
optional = false
python-versions = ">=3.5"
-version = "6.1.7"
[package.dependencies]
jupyter-core = ">=4.6.0"
@@ -702,12 +691,12 @@ traitlets = "*"
test = ["ipykernel", "ipython", "mock", "pytest", "pytest-asyncio", "async-generator", "pytest-timeout"]
[[package]]
-category = "dev"
-description = "Jupyter terminal console"
name = "jupyter-console"
+version = "6.2.0"
+description = "Jupyter terminal console"
+category = "dev"
optional = false
python-versions = ">=3.6"
-version = "6.2.0"
[package.dependencies]
ipykernel = "*"
@@ -720,35 +709,35 @@ pygments = "*"
test = ["pexpect"]
[[package]]
-category = "main"
-description = "Jupyter core package. A base package on which Jupyter projects rely."
name = "jupyter-core"
+version = "4.7.0"
+description = "Jupyter core package. A base package on which Jupyter projects rely."
+category = "main"
optional = false
python-versions = ">=3.6"
-version = "4.7.0"
[package.dependencies]
-pywin32 = ">=1.0"
+pywin32 = {version = ">=1.0", markers = "sys_platform == \"win32\""}
traitlets = "*"
[[package]]
-category = "main"
-description = "Pygments theme using JupyterLab CSS variables"
name = "jupyterlab-pygments"
+version = "0.1.2"
+description = "Pygments theme using JupyterLab CSS variables"
+category = "main"
optional = false
python-versions = "*"
-version = "0.1.2"
[package.dependencies]
pygments = ">=2.4.1,<3"
[[package]]
-category = "main"
-description = "Select and install a Jupyter notebook theme"
name = "jupyterthemes"
+version = "0.20.0"
+description = "Select and install a Jupyter notebook theme"
+category = "main"
optional = false
python-versions = "*"
-version = "0.20.0"
[package.dependencies]
ipython = ">=5.4.1"
@@ -758,69 +747,69 @@ matplotlib = ">=1.4.3"
notebook = ">=5.6.0"
[[package]]
-category = "main"
-description = "A fast implementation of the Cassowary constraint solver"
name = "kiwisolver"
+version = "1.3.1"
+description = "A fast implementation of the Cassowary constraint solver"
+category = "main"
optional = false
python-versions = ">=3.6"
-version = "1.3.1"
[[package]]
-category = "main"
-description = "Python LESS compiler"
name = "lesscpy"
+version = "0.14.0"
+description = "Python LESS compiler"
+category = "main"
optional = false
python-versions = "*"
-version = "0.14.0"
[package.dependencies]
ply = "*"
six = "*"
[[package]]
-category = "main"
-description = "Python logging made (stupidly) simple"
name = "loguru"
+version = "0.5.3"
+description = "Python logging made (stupidly) simple"
+category = "main"
optional = false
python-versions = ">=3.5"
-version = "0.5.3"
[package.dependencies]
-colorama = ">=0.3.4"
-win32-setctime = ">=1.0.0"
+colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""}
+win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""}
[package.extras]
dev = ["codecov (>=2.0.15)", "colorama (>=0.3.4)", "flake8 (>=3.7.7)", "tox (>=3.9.0)", "tox-travis (>=0.12)", "pytest (>=4.6.2)", "pytest-cov (>=2.7.1)", "Sphinx (>=2.2.1)", "sphinx-autobuild (>=0.7.1)", "sphinx-rtd-theme (>=0.4.3)", "black (>=19.10b0)", "isort (>=5.1.1)"]
[[package]]
-category = "main"
-description = "Safely add untrusted strings to HTML/XML markup."
name = "markupsafe"
+version = "1.1.1"
+description = "Safely add untrusted strings to HTML/XML markup."
+category = "main"
optional = false
python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*"
-version = "1.1.1"
[[package]]
-category = "main"
-description = "A lightweight library for converting complex datatypes to and from native Python datatypes."
name = "marshmallow"
+version = "3.9.1"
+description = "A lightweight library for converting complex datatypes to and from native Python datatypes."
+category = "main"
optional = false
python-versions = ">=3.5"
-version = "3.9.1"
[package.extras]
-dev = ["pytest", "pytz", "simplejson", "mypy (0.790)", "flake8 (3.8.4)", "flake8-bugbear (20.1.4)", "pre-commit (>=2.4,<3.0)", "tox"]
-docs = ["sphinx (3.3.0)", "sphinx-issues (1.2.0)", "alabaster (0.7.12)", "sphinx-version-warning (1.1.2)", "autodocsumm (0.2.1)"]
-lint = ["mypy (0.790)", "flake8 (3.8.4)", "flake8-bugbear (20.1.4)", "pre-commit (>=2.4,<3.0)"]
+dev = ["pytest", "pytz", "simplejson", "mypy (==0.790)", "flake8 (==3.8.4)", "flake8-bugbear (==20.1.4)", "pre-commit (>=2.4,<3.0)", "tox"]
+docs = ["sphinx (==3.3.0)", "sphinx-issues (==1.2.0)", "alabaster (==0.7.12)", "sphinx-version-warning (==1.1.2)", "autodocsumm (==0.2.1)"]
+lint = ["mypy (==0.790)", "flake8 (==3.8.4)", "flake8-bugbear (==20.1.4)", "pre-commit (>=2.4,<3.0)"]
tests = ["pytest", "pytz", "simplejson"]
[[package]]
-category = "main"
-description = "Python plotting package"
name = "matplotlib"
+version = "3.3.3"
+description = "Python plotting package"
+category = "main"
optional = false
python-versions = ">=3.6"
-version = "3.3.3"
[package.dependencies]
cycler = ">=0.10"
@@ -831,36 +820,36 @@ pyparsing = ">=2.0.3,<2.0.4 || >2.0.4,<2.1.2 || >2.1.2,<2.1.6 || >2.1.6"
python-dateutil = ">=2.1"
[[package]]
-category = "main"
-description = "McCabe checker, plugin for flake8"
name = "mccabe"
+version = "0.6.1"
+description = "McCabe checker, plugin for flake8"
+category = "main"
optional = false
python-versions = "*"
-version = "0.6.1"
[[package]]
-category = "main"
-description = "The fastest markdown parser in pure Python"
name = "mistune"
+version = "0.8.4"
+description = "The fastest markdown parser in pure Python"
+category = "main"
optional = false
python-versions = "*"
-version = "0.8.4"
[[package]]
-category = "main"
-description = "More routines for operating on iterables, beyond itertools"
name = "more-itertools"
+version = "8.6.0"
+description = "More routines for operating on iterables, beyond itertools"
+category = "main"
optional = false
python-versions = ">=3.5"
-version = "8.6.0"
[[package]]
-category = "dev"
-description = "Optional static typing for Python"
name = "mypy"
+version = "0.770"
+description = "Optional static typing for Python"
+category = "dev"
optional = false
python-versions = ">=3.5"
-version = "0.770"
[package.dependencies]
mypy-extensions = ">=0.4.3,<0.5.0"
@@ -871,20 +860,20 @@ typing-extensions = ">=3.7.4"
dmypy = ["psutil (>=4.0)"]
[[package]]
-category = "main"
-description = "Experimental type system extensions for programs checked with the mypy typechecker."
name = "mypy-extensions"
+version = "0.4.3"
+description = "Experimental type system extensions for programs checked with the mypy typechecker."
+category = "main"
optional = false
python-versions = "*"
-version = "0.4.3"
[[package]]
-category = "main"
-description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor."
name = "nbclient"
+version = "0.5.1"
+description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor."
+category = "main"
optional = false
python-versions = ">=3.6"
-version = "0.5.1"
[package.dependencies]
async-generator = "*"
@@ -899,12 +888,12 @@ sphinx = ["Sphinx (>=1.7)", "sphinx-book-theme", "mock", "moto", "myst-parser"]
test = ["codecov", "coverage", "ipython", "ipykernel", "ipywidgets", "pytest (>=4.1)", "pytest-cov (>=2.6.1)", "check-manifest", "flake8", "mypy", "tox", "bumpversion", "xmltodict", "pip (>=18.1)", "wheel (>=0.31.0)", "setuptools (>=38.6.0)", "twine (>=1.11.0)", "black"]
[[package]]
-category = "main"
-description = "Converting Jupyter Notebooks"
name = "nbconvert"
+version = "6.0.7"
+description = "Converting Jupyter Notebooks"
+category = "main"
optional = false
python-versions = ">=3.6"
-version = "6.0.7"
[package.dependencies]
bleach = "*"
@@ -922,19 +911,19 @@ testpath = "*"
traitlets = ">=4.2"
[package.extras]
-all = ["pytest", "pytest-cov", "pytest-dependency", "ipykernel", "ipywidgets (>=7)", "pyppeteer (0.2.2)", "tornado (>=4.0)", "sphinx (>=1.5.1)", "sphinx-rtd-theme", "nbsphinx (>=0.2.12)", "ipython"]
+all = ["pytest", "pytest-cov", "pytest-dependency", "ipykernel", "ipywidgets (>=7)", "pyppeteer (==0.2.2)", "tornado (>=4.0)", "sphinx (>=1.5.1)", "sphinx-rtd-theme", "nbsphinx (>=0.2.12)", "ipython"]
docs = ["sphinx (>=1.5.1)", "sphinx-rtd-theme", "nbsphinx (>=0.2.12)", "ipython"]
serve = ["tornado (>=4.0)"]
-test = ["pytest", "pytest-cov", "pytest-dependency", "ipykernel", "ipywidgets (>=7)", "pyppeteer (0.2.2)"]
-webpdf = ["pyppeteer (0.2.2)"]
+test = ["pytest", "pytest-cov", "pytest-dependency", "ipykernel", "ipywidgets (>=7)", "pyppeteer (==0.2.2)"]
+webpdf = ["pyppeteer (==0.2.2)"]
[[package]]
-category = "main"
-description = "The Jupyter Notebook format"
name = "nbformat"
+version = "5.0.8"
+description = "The Jupyter Notebook format"
+category = "main"
optional = false
python-versions = ">=3.5"
-version = "5.0.8"
[package.dependencies]
ipython-genutils = "*"
@@ -947,20 +936,20 @@ fast = ["fastjsonschema"]
test = ["fastjsonschema", "testpath", "pytest", "pytest-cov"]
[[package]]
-category = "main"
-description = "Patch asyncio to allow nested event loops"
name = "nest-asyncio"
+version = "1.4.3"
+description = "Patch asyncio to allow nested event loops"
+category = "main"
optional = false
python-versions = ">=3.5"
-version = "1.4.3"
[[package]]
-category = "main"
-description = "Natural Language Toolkit"
name = "nltk"
+version = "3.5"
+description = "Natural Language Toolkit"
+category = "main"
optional = false
python-versions = "*"
-version = "3.5"
[package.dependencies]
click = "*"
@@ -977,15 +966,14 @@ tgrep = ["pyparsing"]
twitter = ["twython"]
[[package]]
-category = "main"
-description = "A web-based notebook environment for interactive computing"
name = "notebook"
+version = "6.1.5"
+description = "A web-based notebook environment for interactive computing"
+category = "main"
optional = false
python-versions = ">=3.5"
-version = "6.1.5"
[package.dependencies]
-Send2Trash = "*"
argon2-cffi = "*"
ipykernel = "*"
ipython-genutils = "*"
@@ -996,6 +984,7 @@ nbconvert = "*"
nbformat = "*"
prometheus-client = "*"
pyzmq = ">=17"
+Send2Trash = "*"
terminado = ">=0.8.3"
tornado = ">=5.0"
traitlets = ">=4.2.1"
@@ -1005,161 +994,160 @@ docs = ["sphinx", "nbsphinx", "sphinxcontrib-github-alt"]
test = ["nose", "coverage", "requests", "nose-warnings-filters", "nbval", "nose-exclude", "selenium", "pytest", "pytest-cov", "requests-unixsocket"]
[[package]]
-category = "main"
-description = "NumPy is the fundamental package for array computing with Python."
name = "numpy"
+version = "1.19.4"
+description = "NumPy is the fundamental package for array computing with Python."
+category = "main"
optional = false
python-versions = ">=3.6"
-version = "1.19.4"
[[package]]
-category = "dev"
-description = "Python Bindings for the NVIDIA Management Library"
name = "nvidia-ml-py3"
+version = "7.352.0"
+description = "Python Bindings for the NVIDIA Management Library"
+category = "dev"
optional = false
python-versions = "*"
-version = "7.352.0"
[[package]]
-category = "main"
-description = "A flexible configuration library"
name = "omegaconf"
+version = "2.0.5"
+description = "A flexible configuration library"
+category = "main"
optional = false
python-versions = ">=3.6"
-version = "2.0.5"
[package.dependencies]
PyYAML = ">=5.1"
typing-extensions = "*"
[[package]]
-category = "main"
-description = "Wrapper package for OpenCV python bindings."
name = "opencv-python"
+version = "4.4.0.46"
+description = "Wrapper package for OpenCV python bindings."
+category = "main"
optional = false
python-versions = ">=3.6"
-version = "4.4.0.46"
[[package]]
-category = "main"
-description = "Core utilities for Python packages"
name = "packaging"
+version = "20.4"
+description = "Core utilities for Python packages"
+category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
-version = "20.4"
[package.dependencies]
pyparsing = ">=2.0.2"
six = "*"
[[package]]
-category = "main"
-description = "Utilities for writing pandoc filters in python"
name = "pandocfilters"
+version = "1.4.3"
+description = "Utilities for writing pandoc filters in python"
+category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
-version = "1.4.3"
[[package]]
-category = "main"
-description = "A Python Parser"
name = "parso"
+version = "0.7.1"
+description = "A Python Parser"
+category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
-version = "0.7.1"
[package.extras]
testing = ["docopt", "pytest (>=3.0.7)"]
[[package]]
-category = "dev"
-description = "Utility library for gitignore style pattern matching of file paths."
name = "pathspec"
+version = "0.8.1"
+description = "Utility library for gitignore style pattern matching of file paths."
+category = "dev"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
-version = "0.8.1"
[[package]]
-category = "main"
-description = "File system general utilities"
name = "pathtools"
+version = "0.1.2"
+description = "File system general utilities"
+category = "main"
optional = false
python-versions = "*"
-version = "0.1.2"
[[package]]
-category = "dev"
-description = "Python Build Reasonableness"
name = "pbr"
+version = "5.5.1"
+description = "Python Build Reasonableness"
+category = "dev"
optional = false
python-versions = ">=2.6"
-version = "5.5.1"
[[package]]
-category = "main"
-description = "Pexpect allows easy control of interactive console applications."
-marker = "python_version >= \"3.3\" and sys_platform != \"win32\" or sys_platform != \"win32\""
name = "pexpect"
+version = "4.8.0"
+description = "Pexpect allows easy control of interactive console applications."
+category = "main"
optional = false
python-versions = "*"
-version = "4.8.0"
[package.dependencies]
ptyprocess = ">=0.5"
[[package]]
-category = "main"
-description = "Tiny 'shelve'-like database with concurrency support"
name = "pickleshare"
+version = "0.7.5"
+description = "Tiny 'shelve'-like database with concurrency support"
+category = "main"
optional = false
python-versions = "*"
-version = "0.7.5"
[[package]]
-category = "main"
-description = "Python Imaging Library (Fork)"
name = "pillow"
+version = "8.0.1"
+description = "Python Imaging Library (Fork)"
+category = "main"
optional = false
python-versions = ">=3.6"
-version = "8.0.1"
[[package]]
-category = "main"
-description = "plugin and hook calling mechanisms for python"
name = "pluggy"
+version = "0.13.1"
+description = "plugin and hook calling mechanisms for python"
+category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
-version = "0.13.1"
[package.extras]
dev = ["pre-commit", "tox"]
[[package]]
-category = "main"
-description = "Python Lex & Yacc"
name = "ply"
+version = "3.11"
+description = "Python Lex & Yacc"
+category = "main"
optional = false
python-versions = "*"
-version = "3.11"
[[package]]
-category = "main"
-description = "Python client for the Prometheus monitoring system."
name = "prometheus-client"
+version = "0.9.0"
+description = "Python client for the Prometheus monitoring system."
+category = "main"
optional = false
python-versions = "*"
-version = "0.9.0"
[package.extras]
twisted = ["twisted"]
[[package]]
-category = "main"
-description = "Promises/A+ implementation for Python"
name = "promise"
+version = "2.3"
+description = "Promises/A+ implementation for Python"
+category = "main"
optional = false
python-versions = "*"
-version = "2.3"
[package.dependencies]
six = "*"
@@ -1168,126 +1156,125 @@ six = "*"
test = ["pytest (>=2.7.3)", "pytest-cov", "coveralls", "futures", "pytest-benchmark", "mock"]
[[package]]
-category = "main"
-description = "Library for building powerful interactive command lines in Python"
name = "prompt-toolkit"
+version = "3.0.8"
+description = "Library for building powerful interactive command lines in Python"
+category = "main"
optional = false
python-versions = ">=3.6.1"
-version = "3.0.8"
[package.dependencies]
wcwidth = "*"
[[package]]
-category = "main"
-description = "Protocol Buffers"
name = "protobuf"
+version = "3.14.0"
+description = "Protocol Buffers"
+category = "main"
optional = false
python-versions = "*"
-version = "3.14.0"
[package.dependencies]
six = ">=1.9"
[[package]]
-category = "main"
-description = "Cross-platform lib for process and system monitoring in Python."
name = "psutil"
+version = "5.7.3"
+description = "Cross-platform lib for process and system monitoring in Python."
+category = "main"
optional = false
python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
-version = "5.7.3"
[package.extras]
test = ["ipaddress", "mock", "unittest2", "enum34", "pywin32", "wmi"]
[[package]]
-category = "main"
-description = "Run a subprocess in a pseudo terminal"
-marker = "python_version >= \"3.3\" and sys_platform != \"win32\" or sys_platform != \"win32\" or os_name != \"nt\" or python_version >= \"3.3\" and sys_platform != \"win32\" and (python_version >= \"3.3\" and sys_platform != \"win32\" or sys_platform != \"win32\")"
name = "ptyprocess"
+version = "0.6.0"
+description = "Run a subprocess in a pseudo terminal"
+category = "main"
optional = false
python-versions = "*"
-version = "0.6.0"
[[package]]
-category = "main"
-description = "library with cross-python path, ini-parsing, io, code, log facilities"
name = "py"
+version = "1.9.0"
+description = "library with cross-python path, ini-parsing, io, code, log facilities"
+category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
-version = "1.9.0"
[[package]]
-category = "main"
-description = "Python style guide checker"
name = "pycodestyle"
+version = "2.6.0"
+description = "Python style guide checker"
+category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
-version = "2.6.0"
[[package]]
-category = "main"
-description = "C parser in Python"
name = "pycparser"
+version = "2.20"
+description = "C parser in Python"
+category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
-version = "2.20"
[[package]]
-category = "main"
-description = "Python docstring style checker"
name = "pydocstyle"
+version = "5.1.1"
+description = "Python docstring style checker"
+category = "main"
optional = false
python-versions = ">=3.5"
-version = "5.1.1"
[package.dependencies]
snowballstemmer = "*"
[[package]]
-category = "main"
-description = "passive checker of Python programs"
name = "pyflakes"
+version = "2.2.0"
+description = "passive checker of Python programs"
+category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
-version = "2.2.0"
[[package]]
-category = "main"
-description = "Pygments is a syntax highlighting package written in Python."
name = "pygments"
+version = "2.7.2"
+description = "Pygments is a syntax highlighting package written in Python."
+category = "main"
optional = false
python-versions = ">=3.5"
-version = "2.7.2"
[[package]]
-category = "main"
-description = "Python parsing module"
name = "pyparsing"
+version = "2.4.7"
+description = "Python parsing module"
+category = "main"
optional = false
python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
-version = "2.4.7"
[[package]]
-category = "main"
-description = "Persistent/Functional/Immutable data structures"
name = "pyrsistent"
+version = "0.17.3"
+description = "Persistent/Functional/Immutable data structures"
+category = "main"
optional = false
python-versions = ">=3.5"
-version = "0.17.3"
[[package]]
-category = "main"
-description = "pytest: simple powerful testing with Python"
name = "pytest"
+version = "5.4.3"
+description = "pytest: simple powerful testing with Python"
+category = "main"
optional = false
python-versions = ">=3.5"
-version = "5.4.3"
[package.dependencies]
-atomicwrites = ">=1.0"
+atomicwrites = {version = ">=1.0", markers = "sys_platform == \"win32\""}
attrs = ">=17.4.0"
-colorama = "*"
+colorama = {version = "*", markers = "sys_platform == \"win32\""}
more-itertools = ">=4.0.0"
packaging = "*"
pluggy = ">=0.12,<1.0"
@@ -1295,31 +1282,31 @@ py = ">=1.5.0"
wcwidth = "*"
[package.extras]
-checkqa-mypy = ["mypy (v0.761)"]
+checkqa-mypy = ["mypy (==v0.761)"]
testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"]
[[package]]
-category = "dev"
-description = "Pytest plugin for measuring coverage."
name = "pytest-cov"
+version = "2.10.1"
+description = "Pytest plugin for measuring coverage."
+category = "dev"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
-version = "2.10.1"
[package.dependencies]
coverage = ">=4.4"
pytest = ">=4.6"
[package.extras]
-testing = ["fields", "hunter", "process-tests (2.0.2)", "six", "pytest-xdist", "virtualenv"]
+testing = ["fields", "hunter", "process-tests (==2.0.2)", "six", "pytest-xdist", "virtualenv"]
[[package]]
-category = "dev"
-description = "Thin-wrapper around the mock package for easier use with pytest"
name = "pytest-mock"
+version = "3.3.1"
+description = "Thin-wrapper around the mock package for easier use with pytest"
+category = "dev"
optional = false
python-versions = ">=3.5"
-version = "3.3.1"
[package.dependencies]
pytest = ">=5.0"
@@ -1328,34 +1315,31 @@ pytest = ">=5.0"
dev = ["pre-commit", "tox", "pytest-asyncio"]
[[package]]
-category = "main"
-description = "Extensions to the standard Python datetime module"
name = "python-dateutil"
+version = "2.8.1"
+description = "Extensions to the standard Python datetime module"
+category = "main"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
-version = "2.8.1"
[package.dependencies]
six = ">=1.5"
[[package]]
-category = "main"
-description = "Python extension for computing string edit distances and similarities."
name = "python-levenshtein"
+version = "0.12.0"
+description = "Python extension for computing string edit distances and similarities."
+category = "main"
optional = false
python-versions = "*"
-version = "0.12.0"
-
-[package.dependencies]
-setuptools = "*"
[[package]]
-category = "main"
-description = "The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch."
name = "pytorch-metric-learning"
+version = "0.9.94"
+description = "The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch."
+category = "main"
optional = false
python-versions = ">=3.0"
-version = "0.9.94"
[package.dependencies]
numpy = "*"
@@ -1370,58 +1354,56 @@ with-hooks = ["record-keeper (>=0.9.29)", "faiss-gpu (>=1.6.3)", "tensorboard"]
with-hooks-cpu = ["record-keeper (>=0.9.29)", "faiss-cpu (>=1.6.3)", "tensorboard"]
[[package]]
-category = "main"
-description = "World timezone definitions, modern and historical"
name = "pytz"
+version = "2020.4"
+description = "World timezone definitions, modern and historical"
+category = "main"
optional = false
python-versions = "*"
-version = "2020.4"
[[package]]
-category = "main"
-description = "Python for Window Extensions"
-marker = "sys_platform == \"win32\""
name = "pywin32"
+version = "300"
+description = "Python for Window Extensions"
+category = "main"
optional = false
python-versions = "*"
-version = "300"
[[package]]
-category = "main"
-description = "Python bindings for the winpty library"
-marker = "os_name == \"nt\""
name = "pywinpty"
+version = "0.5.7"
+description = "Python bindings for the winpty library"
+category = "main"
optional = false
python-versions = "*"
-version = "0.5.7"
[[package]]
-category = "main"
-description = "YAML parser and emitter for Python"
name = "pyyaml"
+version = "5.3.1"
+description = "YAML parser and emitter for Python"
+category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
-version = "5.3.1"
[[package]]
-category = "main"
-description = "Python bindings for 0MQ"
name = "pyzmq"
+version = "20.0.0"
+description = "Python bindings for 0MQ"
+category = "main"
optional = false
python-versions = ">=3.5"
-version = "20.0.0"
[package.dependencies]
-cffi = "*"
-py = "*"
+cffi = {version = "*", markers = "implementation_name === \"pypy\""}
+py = {version = "*", markers = "implementation_name === \"pypy\""}
[[package]]
-category = "dev"
-description = "Jupyter Qt console"
name = "qtconsole"
+version = "4.7.7"
+description = "Jupyter Qt console"
+category = "dev"
optional = false
python-versions = "*"
-version = "4.7.7"
[package.dependencies]
ipykernel = ">=4.1"
@@ -1438,50 +1420,50 @@ doc = ["Sphinx (>=1.3)"]
test = ["pytest", "mock"]
[[package]]
-category = "dev"
-description = "Provides an abstraction layer on top of the various Qt bindings (PyQt5, PyQt4 and PySide) and additional custom QWidgets."
name = "qtpy"
+version = "1.9.0"
+description = "Provides an abstraction layer on top of the various Qt bindings (PyQt5, PyQt4 and PySide) and additional custom QWidgets."
+category = "dev"
optional = false
python-versions = "*"
-version = "1.9.0"
[[package]]
-category = "dev"
-description = "Python client for Redis key-value store"
name = "redis"
+version = "3.5.3"
+description = "Python client for Redis key-value store"
+category = "dev"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
-version = "3.5.3"
[package.extras]
hiredis = ["hiredis (>=0.1.3)"]
[[package]]
-category = "dev"
-description = "Redis locking mechanism"
name = "redlock-py"
+version = "1.0.8"
+description = "Redis locking mechanism"
+category = "dev"
optional = false
python-versions = "*"
-version = "1.0.8"
[package.dependencies]
redis = "*"
[[package]]
-category = "main"
-description = "Alternative regular expression module, to replace re."
name = "regex"
+version = "2020.11.13"
+description = "Alternative regular expression module, to replace re."
+category = "main"
optional = false
python-versions = "*"
-version = "2020.11.13"
[[package]]
-category = "main"
-description = "Python HTTP for Humans."
name = "requests"
+version = "2.25.0"
+description = "Python HTTP for Humans."
+category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
-version = "2.25.0"
[package.dependencies]
certifi = ">=2017.4.17"
@@ -1491,30 +1473,29 @@ urllib3 = ">=1.21.1,<1.27"
[package.extras]
security = ["pyOpenSSL (>=0.14)", "cryptography (>=1.3.4)"]
-socks = ["PySocks (>=1.5.6,<1.5.7 || >1.5.7)", "win-inet-pton"]
+socks = ["PySocks (>=1.5.6,!=1.5.7)", "win-inet-pton"]
[[package]]
-category = "dev"
-description = "Checks installed dependencies for known vulnerabilities."
name = "safety"
+version = "1.9.0"
+description = "Checks installed dependencies for known vulnerabilities."
+category = "dev"
optional = false
python-versions = ">=3.5"
-version = "1.9.0"
[package.dependencies]
Click = ">=6.0"
dparse = ">=0.5.1"
packaging = "*"
requests = "*"
-setuptools = "*"
[[package]]
-category = "main"
-description = "A set of python modules for machine learning and data mining"
name = "scikit-learn"
+version = "0.23.2"
+description = "A set of python modules for machine learning and data mining"
+category = "main"
optional = false
python-versions = ">=3.6"
-version = "0.23.2"
[package.dependencies]
joblib = ">=0.11"
@@ -1526,39 +1507,39 @@ threadpoolctl = ">=2.0.0"
alldeps = ["numpy (>=1.13.3)", "scipy (>=0.19.1)"]
[[package]]
-category = "main"
-description = "SciPy: Scientific Library for Python"
name = "scipy"
+version = "1.5.4"
+description = "SciPy: Scientific Library for Python"
+category = "main"
optional = false
python-versions = ">=3.6"
-version = "1.5.4"
[package.dependencies]
numpy = ">=1.14.5"
[[package]]
-category = "main"
-description = "Send file to trash natively under Mac OS X, Windows and Linux."
name = "send2trash"
+version = "1.5.0"
+description = "Send file to trash natively under Mac OS X, Windows and Linux."
+category = "main"
optional = false
python-versions = "*"
-version = "1.5.0"
[[package]]
-category = "main"
-description = "SentencePiece python wrapper"
name = "sentencepiece"
+version = "0.1.95"
+description = "SentencePiece python wrapper"
+category = "main"
optional = false
python-versions = "*"
-version = "0.1.95"
[[package]]
-category = "main"
-description = "Python client for Sentry (https://sentry.io)"
name = "sentry-sdk"
+version = "0.19.4"
+description = "Python client for Sentry (https://sentry.io)"
+category = "main"
optional = false
python-versions = "*"
-version = "0.19.4"
[package.dependencies]
certifi = "*"
@@ -1581,56 +1562,55 @@ sqlalchemy = ["sqlalchemy (>=1.2)"]
tornado = ["tornado (>=5)"]
[[package]]
-category = "main"
-description = "A generator library for concise, unambiguous and URL-safe UUIDs."
name = "shortuuid"
+version = "1.0.1"
+description = "A generator library for concise, unambiguous and URL-safe UUIDs."
+category = "main"
optional = false
python-versions = ">=3.5"
-version = "1.0.1"
[[package]]
-category = "main"
-description = "Python 2 and 3 compatibility utilities"
name = "six"
+version = "1.15.0"
+description = "Python 2 and 3 compatibility utilities"
+category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
-version = "1.15.0"
[[package]]
-category = "main"
-description = "A pure Python implementation of a sliding window memory map manager"
name = "smmap"
+version = "3.0.4"
+description = "A pure Python implementation of a sliding window memory map manager"
+category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
-version = "3.0.4"
[[package]]
-category = "main"
-description = "This package provides 26 stemmers for 25 languages generated from Snowball algorithms."
name = "snowballstemmer"
+version = "2.0.0"
+description = "This package provides 26 stemmers for 25 languages generated from Snowball algorithms."
+category = "main"
optional = false
python-versions = "*"
-version = "2.0.0"
[[package]]
-category = "main"
-description = "Python documentation generator"
name = "sphinx"
+version = "3.3.1"
+description = "Python documentation generator"
+category = "main"
optional = false
python-versions = ">=3.5"
-version = "3.3.1"
[package.dependencies]
-Jinja2 = ">=2.3"
-Pygments = ">=2.0"
alabaster = ">=0.7,<0.8"
babel = ">=1.3"
-colorama = ">=0.3.5"
+colorama = {version = ">=0.3.5", markers = "sys_platform == \"win32\""}
docutils = ">=0.12"
imagesize = "*"
+Jinja2 = ">=2.3"
packaging = "*"
+Pygments = ">=2.0"
requests = ">=2.5.0"
-setuptools = "*"
snowballstemmer = ">=1.1"
sphinxcontrib-applehelp = "*"
sphinxcontrib-devhelp = "*"
@@ -1645,12 +1625,12 @@ lint = ["flake8 (>=3.5.0)", "flake8-import-order", "mypy (>=0.790)", "docutils-s
test = ["pytest", "pytest-cov", "html5lib", "typed-ast", "cython"]
[[package]]
-category = "main"
-description = "Type hints (PEP 484) support for the Sphinx autodoc extension"
name = "sphinx-autodoc-typehints"
+version = "1.11.1"
+description = "Type hints (PEP 484) support for the Sphinx autodoc extension"
+category = "main"
optional = false
python-versions = ">=3.5.2"
-version = "1.11.1"
[package.dependencies]
Sphinx = ">=3.0"
@@ -1660,153 +1640,153 @@ test = ["pytest (>=3.1.0)", "typing-extensions (>=3.5)", "sphobjinv (>=2.0)", "S
type_comments = ["typed-ast (>=1.4.0)"]
[[package]]
-category = "main"
-description = "Read the Docs theme for Sphinx"
name = "sphinx-rtd-theme"
+version = "0.4.3"
+description = "Read the Docs theme for Sphinx"
+category = "main"
optional = false
python-versions = "*"
-version = "0.4.3"
[package.dependencies]
sphinx = "*"
[[package]]
-category = "main"
-description = "sphinxcontrib-applehelp is a sphinx extension which outputs Apple help books"
name = "sphinxcontrib-applehelp"
+version = "1.0.2"
+description = "sphinxcontrib-applehelp is a sphinx extension which outputs Apple help books"
+category = "main"
optional = false
python-versions = ">=3.5"
-version = "1.0.2"
[package.extras]
lint = ["flake8", "mypy", "docutils-stubs"]
test = ["pytest"]
[[package]]
-category = "main"
-description = "sphinxcontrib-devhelp is a sphinx extension which outputs Devhelp document."
name = "sphinxcontrib-devhelp"
+version = "1.0.2"
+description = "sphinxcontrib-devhelp is a sphinx extension which outputs Devhelp document."
+category = "main"
optional = false
python-versions = ">=3.5"
-version = "1.0.2"
[package.extras]
lint = ["flake8", "mypy", "docutils-stubs"]
test = ["pytest"]
[[package]]
-category = "main"
-description = "sphinxcontrib-htmlhelp is a sphinx extension which renders HTML help files"
name = "sphinxcontrib-htmlhelp"
+version = "1.0.3"
+description = "sphinxcontrib-htmlhelp is a sphinx extension which renders HTML help files"
+category = "main"
optional = false
python-versions = ">=3.5"
-version = "1.0.3"
[package.extras]
lint = ["flake8", "mypy", "docutils-stubs"]
test = ["pytest", "html5lib"]
[[package]]
-category = "main"
-description = "A sphinx extension which renders display math in HTML via JavaScript"
name = "sphinxcontrib-jsmath"
+version = "1.0.1"
+description = "A sphinx extension which renders display math in HTML via JavaScript"
+category = "main"
optional = false
python-versions = ">=3.5"
-version = "1.0.1"
[package.extras]
test = ["pytest", "flake8", "mypy"]
[[package]]
-category = "main"
-description = "sphinxcontrib-qthelp is a sphinx extension which outputs QtHelp document."
name = "sphinxcontrib-qthelp"
+version = "1.0.3"
+description = "sphinxcontrib-qthelp is a sphinx extension which outputs QtHelp document."
+category = "main"
optional = false
python-versions = ">=3.5"
-version = "1.0.3"
[package.extras]
lint = ["flake8", "mypy", "docutils-stubs"]
test = ["pytest"]
[[package]]
-category = "main"
-description = "sphinxcontrib-serializinghtml is a sphinx extension which outputs \"serialized\" HTML files (json and pickle)."
name = "sphinxcontrib-serializinghtml"
+version = "1.1.4"
+description = "sphinxcontrib-serializinghtml is a sphinx extension which outputs \"serialized\" HTML files (json and pickle)."
+category = "main"
optional = false
python-versions = ">=3.5"
-version = "1.1.4"
[package.extras]
lint = ["flake8", "mypy", "docutils-stubs"]
test = ["pytest"]
[[package]]
-category = "dev"
-description = "Manage dynamic plugins for Python applications"
name = "stevedore"
+version = "3.2.2"
+description = "Manage dynamic plugins for Python applications"
+category = "dev"
optional = false
python-versions = ">=3.6"
-version = "3.2.2"
[package.dependencies]
pbr = ">=2.0.0,<2.1.0 || >2.1.0"
[[package]]
-category = "main"
-description = "A backport of the subprocess module from Python 3 for use on 2.x."
name = "subprocess32"
+version = "3.5.4"
+description = "A backport of the subprocess module from Python 3 for use on 2.x."
+category = "main"
optional = false
python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*, <4"
-version = "3.5.4"
[[package]]
-category = "main"
-description = "Tornado websocket backend for the Xterm.js Javascript terminal emulator library."
name = "terminado"
+version = "0.9.1"
+description = "Tornado websocket backend for the Xterm.js Javascript terminal emulator library."
+category = "main"
optional = false
python-versions = ">=3.6"
-version = "0.9.1"
[package.dependencies]
-ptyprocess = "*"
-pywinpty = ">=0.5"
+ptyprocess = {version = "*", markers = "os_name != \"nt\""}
+pywinpty = {version = ">=0.5", markers = "os_name == \"nt\""}
tornado = ">=4"
[[package]]
-category = "main"
-description = "Test utilities for code working with files and commands"
name = "testpath"
+version = "0.4.4"
+description = "Test utilities for code working with files and commands"
+category = "main"
optional = false
python-versions = "*"
-version = "0.4.4"
[package.extras]
test = ["pathlib2"]
[[package]]
-category = "main"
-description = "threadpoolctl"
name = "threadpoolctl"
+version = "2.1.0"
+description = "threadpoolctl"
+category = "main"
optional = false
python-versions = ">=3.5"
-version = "2.1.0"
[[package]]
-category = "main"
-description = "Python Library for Tom's Obvious, Minimal Language"
name = "toml"
+version = "0.10.2"
+description = "Python Library for Tom's Obvious, Minimal Language"
+category = "main"
optional = false
python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
-version = "0.10.2"
[[package]]
-category = "main"
-description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
name = "torch"
+version = "1.7.0"
+description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
+category = "main"
optional = false
python-versions = ">=3.6.1"
-version = "1.7.0"
[package.dependencies]
dataclasses = "*"
@@ -1815,20 +1795,20 @@ numpy = "*"
typing-extensions = "*"
[[package]]
-category = "main"
-description = "Model summary in PyTorch, based off of the original torchsummary."
name = "torch-summary"
+version = "1.4.3"
+description = "Model summary in PyTorch, based off of the original torchsummary."
+category = "main"
optional = false
python-versions = ">=3.5"
-version = "1.4.3"
[[package]]
-category = "main"
-description = "image and video datasets and models for torch deep learning"
name = "torchvision"
+version = "0.8.1"
+description = "image and video datasets and models for torch deep learning"
+category = "main"
optional = false
python-versions = "*"
-version = "0.8.1"
[package.dependencies]
numpy = "*"
@@ -1839,31 +1819,31 @@ torch = "1.7.0"
scipy = ["scipy"]
[[package]]
-category = "main"
-description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed."
name = "tornado"
+version = "6.1"
+description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed."
+category = "main"
optional = false
python-versions = ">= 3.5"
-version = "6.1"
[[package]]
-category = "main"
-description = "Fast, Extensible Progress Meter"
name = "tqdm"
+version = "4.53.0"
+description = "Fast, Extensible Progress Meter"
+category = "main"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7"
-version = "4.53.0"
[package.extras]
dev = ["py-make (>=0.1.0)", "twine", "argopt", "pydoc-markdown", "wheel"]
[[package]]
-category = "main"
-description = "Traitlets Python configuration system"
name = "traitlets"
+version = "5.0.5"
+description = "Traitlets Python configuration system"
+category = "main"
optional = false
python-versions = ">=3.7"
-version = "5.0.5"
[package.dependencies]
ipython-genutils = "*"
@@ -1872,76 +1852,76 @@ ipython-genutils = "*"
test = ["pytest"]
[[package]]
-category = "dev"
-description = "a fork of Python 2 and 3 ast modules with type comment support"
name = "typed-ast"
+version = "1.4.1"
+description = "a fork of Python 2 and 3 ast modules with type comment support"
+category = "dev"
optional = false
python-versions = "*"
-version = "1.4.1"
[[package]]
-category = "dev"
-description = "Run-time type checker for Python"
name = "typeguard"
+version = "2.10.0"
+description = "Run-time type checker for Python"
+category = "dev"
optional = false
python-versions = ">=3.5.3"
-version = "2.10.0"
[package.extras]
doc = ["sphinx-rtd-theme", "sphinx-autodoc-typehints (>=1.2.0)"]
test = ["pytest", "typing-extensions"]
[[package]]
-category = "main"
-description = "Backported and Experimental Type Hints for Python 3.5+"
name = "typing-extensions"
+version = "3.7.4.3"
+description = "Backported and Experimental Type Hints for Python 3.5+"
+category = "main"
optional = false
python-versions = "*"
-version = "3.7.4.3"
[[package]]
-category = "main"
-description = "Runtime inspection utilities for typing module."
name = "typing-inspect"
+version = "0.6.0"
+description = "Runtime inspection utilities for typing module."
+category = "main"
optional = false
python-versions = "*"
-version = "0.6.0"
[package.dependencies]
mypy-extensions = ">=0.3.0"
typing-extensions = ">=3.7.4"
[[package]]
-category = "main"
-description = "HTTP library with thread-safe connection pooling, file post, and more."
name = "urllib3"
+version = "1.26.2"
+description = "HTTP library with thread-safe connection pooling, file post, and more."
+category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4"
-version = "1.26.2"
[package.extras]
brotli = ["brotlipy (>=0.6.0)"]
secure = ["pyOpenSSL (>=0.14)", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "certifi", "ipaddress"]
-socks = ["PySocks (>=1.5.6,<1.5.7 || >1.5.7,<2.0)"]
+socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"]
[[package]]
-category = "main"
-description = "A CLI and library for interacting with the Weights and Biases API."
name = "wandb"
+version = "0.10.12"
+description = "A CLI and library for interacting with the Weights and Biases API."
+category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
-version = "0.10.12"
[package.dependencies]
Click = ">=7.0"
-GitPython = ">=1.0.0"
-PyYAML = "*"
configparser = ">=3.8.1"
docker-pycreds = ">=0.4.0"
+GitPython = ">=1.0.0"
promise = ">=2.0,<3"
protobuf = ">=3.12.0"
psutil = ">=5.0.0"
python-dateutil = ">=2.6.1"
+PyYAML = "*"
requests = ">=2.0.0,<3"
sentry-sdk = ">=0.4.0"
shortuuid = ">=0.5.0"
@@ -1952,16 +1932,16 @@ watchdog = ">=0.8.3"
[package.extras]
aws = ["boto3"]
gcp = ["google-cloud-storage"]
-grpc = ["grpcio (1.27.2)"]
+grpc = ["grpcio (==1.27.2)"]
kubeflow = ["kubernetes", "minio", "google-cloud-storage", "sh"]
[[package]]
-category = "main"
-description = "Filesystem events monitoring"
name = "watchdog"
+version = "0.10.4"
+description = "Filesystem events monitoring"
+category = "main"
optional = false
python-versions = "*"
-version = "0.10.4"
[package.dependencies]
pathtools = ">=0.1.1"
@@ -1970,51 +1950,50 @@ pathtools = ">=0.1.1"
watchmedo = ["PyYAML (>=3.10)", "argh (>=0.24.1)"]
[[package]]
-category = "main"
-description = "Measures the displayed width of unicode strings in a terminal"
name = "wcwidth"
+version = "0.2.5"
+description = "Measures the displayed width of unicode strings in a terminal"
+category = "main"
optional = false
python-versions = "*"
-version = "0.2.5"
[[package]]
-category = "main"
-description = "Character encoding aliases for legacy web content"
name = "webencodings"
+version = "0.5.1"
+description = "Character encoding aliases for legacy web content"
+category = "main"
optional = false
python-versions = "*"
-version = "0.5.1"
[[package]]
-category = "dev"
-description = "IPython HTML widgets for Jupyter"
name = "widgetsnbextension"
+version = "3.5.1"
+description = "IPython HTML widgets for Jupyter"
+category = "dev"
optional = false
python-versions = "*"
-version = "3.5.1"
[package.dependencies]
notebook = ">=4.4.1"
[[package]]
-category = "main"
-description = "A small Python utility to set file creation time on Windows"
-marker = "sys_platform == \"win32\""
name = "win32-setctime"
+version = "1.0.3"
+description = "A small Python utility to set file creation time on Windows"
+category = "main"
optional = false
python-versions = ">=3.5"
-version = "1.0.3"
[package.extras]
dev = ["pytest (>=4.6.2)", "black (>=19.3b0)"]
[[package]]
-category = "dev"
-description = "A rewrite of the builtin doctest module"
name = "xdoctest"
+version = "0.12.0"
+description = "A rewrite of the builtin doctest module"
+category = "dev"
optional = false
python-versions = "*"
-version = "0.12.0"
[package.dependencies]
six = "*"
@@ -2025,9 +2004,9 @@ optional = ["pygments", "colorama"]
tests = ["pytest", "pytest-cov", "codecov", "scikit-build", "cmake", "ninja", "pybind11"]
[metadata]
-content-hash = "1f194d7de179e9676ef1f8e51b83ff15c001627803008ef8225e8e14ab3acab0"
-lock-version = "1.0"
+lock-version = "1.1"
python-versions = "^3.8"
+content-hash = "1f194d7de179e9676ef1f8e51b83ff15c001627803008ef8225e8e14ab3acab0"
[metadata.files]
alabaster = [
diff --git a/src/notebooks/00-testing-stuff-out.ipynb b/src/notebooks/00-testing-stuff-out.ipynb
index 0e4b298..2d6b43c 100644
--- a/src/notebooks/00-testing-stuff-out.ipynb
+++ b/src/notebooks/00-testing-stuff-out.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -25,16 +25,71 @@
},
{
"cell_type": "code",
- "execution_count": 53,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
- "from text_recognizer.networks import CNN"
+ "from text_recognizer.networks import CNN, TDS2d"
]
},
{
"cell_type": "code",
- "execution_count": 63,
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tds2d = TDS2d(**{\n",
+ " \"depth\" : 4,\n",
+ " \"tds_groups\" : [\n",
+ " { \"channels\" : 4, \"num_blocks\" : 3, \"stride\" : [2, 2] },\n",
+ " { \"channels\" : 32, \"num_blocks\" : 3, \"stride\" : [2, 2] },\n",
+ " { \"channels\" : 64, \"num_blocks\" : 3, \"stride\" : [2, 2] },\n",
+ " { \"channels\" : 128, \"num_blocks\" : 3, \"stride\" : [2, 1] },\n",
+ " ],\n",
+ " \"kernel_size\" : [5, 7],\n",
+ " \"dropout_rate\" : 0.1\n",
+ " }, input_dim=32, output_dim=128)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tds2d"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "summary(tds2d, (1, 28, 952), device=\"cpu\", depth=3)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t = torch.randn(2,1, 28, 952)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tds2d(t).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -43,7 +98,7 @@
},
{
"cell_type": "code",
- "execution_count": 79,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -52,54 +107,25 @@
},
{
"cell_type": "code",
- "execution_count": 81,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Sequential(\n",
- " (0): Sequential(\n",
- " (0): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))\n",
- " )\n",
- " (1): Sequential(\n",
- " (0): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))\n",
- " )\n",
- ")"
- ]
- },
- "execution_count": 81,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"nn.Sequential(i,i)"
]
},
{
"cell_type": "code",
- "execution_count": 64,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([2, 128, 1, 59])"
- ]
- },
- "execution_count": 64,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"cnn(t).shape"
]
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -108,7 +134,7 @@
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -117,7 +143,7 @@
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -126,7 +152,7 @@
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -135,160 +161,34 @@
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "29.5"
- ]
- },
- "execution_count": 24,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"5 * 59 / 10"
]
},
{
"cell_type": "code",
- "execution_count": 25,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([2, 1, 28, 952])"
- ]
- },
- "execution_count": 25,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"x.shape"
]
},
{
"cell_type": "code",
- "execution_count": 26,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "===============================================================================================\n",
- "Layer (type:depth-idx) Output Shape Param #\n",
- "===============================================================================================\n",
- "├─Encoder: 1-1 [-1, 32, 5, 59] --\n",
- "| └─Sequential: 2-1 [-1, 32, 5, 59] --\n",
- "| | └─Sequential: 3-1 [-1, 32, 14, 476] 544\n",
- "| | └─Sequential: 3-2 [-1, 128, 7, 238] 65,664\n",
- "| | └─Sequential: 3-3 [-1, 128, 6, 119] 262,272\n",
- "| | └─Sequential: 3-4 [-1, 256, 5, 59] 524,544\n",
- "| | └─_ResidualBlock: 3-5 [-1, 256, 5, 59] 655,360\n",
- "| | └─_ResidualBlock: 3-6 [-1, 256, 5, 59] 655,360\n",
- "| | └─Conv2d: 3-7 [-1, 32, 5, 59] 8,224\n",
- "| └─VectorQuantizer: 2-2 [-1, 32, 5, 59] --\n",
- "├─Decoder: 1-2 [-1, 1, 28, 952] --\n",
- "| └─Sequential: 2-3 [-1, 1, 28, 952] --\n",
- "| └─Sequential: 2-4 [-1, 256, 5, 59] --\n",
- "| └─Sequential: 2 [] --\n",
- "| | └─Sequential: 3-8 [-1, 256, 5, 59] (recursive)\n",
- "| └─Sequential: 2 [] --\n",
- "| | └─Conv2d: 3-9 [-1, 256, 5, 59] 8,448\n",
- "| | └─_ResidualBlock: 3-10 [-1, 256, 5, 59] 655,360\n",
- "| | └─_ResidualBlock: 3-11 [-1, 256, 5, 59] 655,360\n",
- "| └─Sequential: 2-5 [-1, 1, 28, 952] --\n",
- "| └─Sequential: 2 [] --\n",
- "| | └─Sequential: 3-12 [-1, 1, 28, 952] (recursive)\n",
- "| └─Sequential: 2 [] --\n",
- "| | └─Sequential: 3-13 [-1, 128, 6, 118] 524,416\n",
- "| | └─Upsample: 3-14 [-1, 128, 6, 119] --\n",
- "| | └─Sequential: 3-15 [-1, 128, 7, 238] 262,272\n",
- "| | └─Upsample: 3-16 [-1, 128, 7, 238] --\n",
- "| | └─Sequential: 3-17 [-1, 32, 14, 476] 65,568\n",
- "| | └─ConvTranspose2d: 3-18 [-1, 1, 28, 952] 513\n",
- "| | └─Tanh: 3-19 [-1, 1, 28, 952] --\n",
- "===============================================================================================\n",
- "Total params: 4,343,905\n",
- "Trainable params: 4,343,905\n",
- "Non-trainable params: 0\n",
- "Total mult-adds (G): 1.76\n",
- "===============================================================================================\n",
- "Input size (MB): 0.10\n",
- "Forward/backward pass size (MB): 9.32\n",
- "Params size (MB): 16.57\n",
- "Estimated Total Size (MB): 26.00\n",
- "===============================================================================================\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "===============================================================================================\n",
- "Layer (type:depth-idx) Output Shape Param #\n",
- "===============================================================================================\n",
- "├─Encoder: 1-1 [-1, 32, 5, 59] --\n",
- "| └─Sequential: 2-1 [-1, 32, 5, 59] --\n",
- "| | └─Sequential: 3-1 [-1, 32, 14, 476] 544\n",
- "| | └─Sequential: 3-2 [-1, 128, 7, 238] 65,664\n",
- "| | └─Sequential: 3-3 [-1, 128, 6, 119] 262,272\n",
- "| | └─Sequential: 3-4 [-1, 256, 5, 59] 524,544\n",
- "| | └─_ResidualBlock: 3-5 [-1, 256, 5, 59] 655,360\n",
- "| | └─_ResidualBlock: 3-6 [-1, 256, 5, 59] 655,360\n",
- "| | └─Conv2d: 3-7 [-1, 32, 5, 59] 8,224\n",
- "| └─VectorQuantizer: 2-2 [-1, 32, 5, 59] --\n",
- "├─Decoder: 1-2 [-1, 1, 28, 952] --\n",
- "| └─Sequential: 2-3 [-1, 1, 28, 952] --\n",
- "| └─Sequential: 2-4 [-1, 256, 5, 59] --\n",
- "| └─Sequential: 2 [] --\n",
- "| | └─Sequential: 3-8 [-1, 256, 5, 59] (recursive)\n",
- "| └─Sequential: 2 [] --\n",
- "| | └─Conv2d: 3-9 [-1, 256, 5, 59] 8,448\n",
- "| | └─_ResidualBlock: 3-10 [-1, 256, 5, 59] 655,360\n",
- "| | └─_ResidualBlock: 3-11 [-1, 256, 5, 59] 655,360\n",
- "| └─Sequential: 2-5 [-1, 1, 28, 952] --\n",
- "| └─Sequential: 2 [] --\n",
- "| | └─Sequential: 3-12 [-1, 1, 28, 952] (recursive)\n",
- "| └─Sequential: 2 [] --\n",
- "| | └─Sequential: 3-13 [-1, 128, 6, 118] 524,416\n",
- "| | └─Upsample: 3-14 [-1, 128, 6, 119] --\n",
- "| | └─Sequential: 3-15 [-1, 128, 7, 238] 262,272\n",
- "| | └─Upsample: 3-16 [-1, 128, 7, 238] --\n",
- "| | └─Sequential: 3-17 [-1, 32, 14, 476] 65,568\n",
- "| | └─ConvTranspose2d: 3-18 [-1, 1, 28, 952] 513\n",
- "| | └─Tanh: 3-19 [-1, 1, 28, 952] --\n",
- "===============================================================================================\n",
- "Total params: 4,343,905\n",
- "Trainable params: 4,343,905\n",
- "Non-trainable params: 0\n",
- "Total mult-adds (G): 1.76\n",
- "===============================================================================================\n",
- "Input size (MB): 0.10\n",
- "Forward/backward pass size (MB): 9.32\n",
- "Params size (MB): 16.57\n",
- "Estimated Total Size (MB): 26.00\n",
- "==============================================================================================="
- ]
- },
- "execution_count": 26,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"summary(vqvae, (1, 28, 952), device=\"cpu\", depth=3)"
]
},
{
"cell_type": "code",
- "execution_count": 94,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -297,47 +197,25 @@
},
{
"cell_type": "code",
- "execution_count": 107,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([2, 32, 4, 59])"
- ]
- },
- "execution_count": 107,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"up(tt).shape"
]
},
{
"cell_type": "code",
- "execution_count": 104,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([2, 32, 1, 59])"
- ]
- },
- "execution_count": 104,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"tt.shape"
]
},
{
"cell_type": "code",
- "execution_count": 29,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -353,7 +231,7 @@
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -362,27 +240,16 @@
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([1, 30, 2048])"
- ]
- },
- "execution_count": 17,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"e(t).shape"
]
},
{
"cell_type": "code",
- "execution_count": 34,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -391,7 +258,7 @@
},
{
"cell_type": "code",
- "execution_count": 41,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -401,7 +268,7 @@
},
{
"cell_type": "code",
- "execution_count": 43,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -410,7 +277,7 @@
},
{
"cell_type": "code",
- "execution_count": 46,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -419,75 +286,34 @@
},
{
"cell_type": "code",
- "execution_count": 58,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "ename": "ModuleAttributeError",
- "evalue": "'Embedding' object has no attribute 'device'",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mModuleAttributeError\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m<ipython-input-58-657f11e4a017>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0memb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
- "\u001b[0;32m~/.cache/pypoetry/virtualenvs/text-recognizer-N1c_zsdp-py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 776\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 777\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 778\u001b[0;31m raise ModuleAttributeError(\"'{}' object has no attribute '{}'\".format(\n\u001b[0m\u001b[1;32m 779\u001b[0m type(self).__name__, name))\n\u001b[1;32m 780\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;31mModuleAttributeError\u001b[0m: 'Embedding' object has no attribute 'device'"
- ]
- }
- ],
+ "outputs": [],
"source": [
"emb.device"
]
},
{
"cell_type": "code",
- "execution_count": 49,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensor([[-1.0624, 0.0674, 0.9387, ..., -0.1852, -0.1303, 0.8005],\n",
- " [-1.0624, 0.0674, 0.9387, ..., -0.1852, -0.1303, 0.8005],\n",
- " [-1.0624, 0.0674, 0.9387, ..., -0.1852, -0.1303, 0.8005],\n",
- " ...,\n",
- " [-1.0624, 0.0674, 0.9387, ..., -0.1852, -0.1303, 0.8005],\n",
- " [-1.0624, 0.0674, 0.9387, ..., -0.1852, -0.1303, 0.8005],\n",
- " [-1.0624, 0.0674, 0.9387, ..., -0.1852, -0.1303, 0.8005]])"
- ]
- },
- "execution_count": 49,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"ee"
]
},
{
"cell_type": "code",
- "execution_count": 47,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([16, 256])"
- ]
- },
- "execution_count": 47,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"ee.shape"
]
},
{
"cell_type": "code",
- "execution_count": 50,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -496,27 +322,16 @@
},
{
"cell_type": "code",
- "execution_count": 52,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([16, 10, 256])"
- ]
- },
- "execution_count": 52,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"t.shape"
]
},
{
"cell_type": "code",
- "execution_count": 56,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -525,47 +340,25 @@
},
{
"cell_type": "code",
- "execution_count": 57,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([16, 12, 256])"
- ]
- },
- "execution_count": 57,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"t.shape"
]
},
{
"cell_type": "code",
- "execution_count": 42,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([1, 256])"
- ]
- },
- "execution_count": 42,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"e.shape"
]
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -574,7 +367,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -583,7 +376,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -603,7 +396,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -612,7 +405,7 @@
},
{
"cell_type": "code",
- "execution_count": 66,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -621,90 +414,16 @@
},
{
"cell_type": "code",
- "execution_count": 52,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "==========================================================================================\n",
- "Layer (type:depth-idx) Output Shape Param #\n",
- "==========================================================================================\n",
- "├─Sequential: 1-1 [-1, 64, 12, 474] --\n",
- "| └─Conv2d: 2-1 [-1, 64, 12, 474] 3,136\n",
- "| └─BatchNorm2d: 2-2 [-1, 64, 12, 474] 128\n",
- "| └─ReLU: 2-3 [-1, 64, 12, 474] --\n",
- "├─Sequential: 1-2 [-1, 68, 1, 30] --\n",
- "| └─ResidualLayer: 2-4 [-1, 64, 12, 474] --\n",
- "| | └─Sequential: 3-1 [-1, 64, 12, 474] 147,968\n",
- "| └─ResidualLayer: 2-5 [-1, 65, 6, 237] --\n",
- "| | └─Sequential: 3-2 [-1, 65, 6, 237] 156,325\n",
- "| └─ResidualLayer: 2-6 [-1, 66, 3, 119] --\n",
- "| | └─Sequential: 3-3 [-1, 66, 3, 119] 161,172\n",
- "| └─ResidualLayer: 2-7 [-1, 67, 2, 60] --\n",
- "| | └─Sequential: 3-4 [-1, 67, 2, 60] 166,093\n",
- "| └─ResidualLayer: 2-8 [-1, 68, 1, 30] --\n",
- "| | └─Sequential: 3-5 [-1, 68, 1, 30] 171,088\n",
- "==========================================================================================\n",
- "Total params: 805,910\n",
- "Trainable params: 805,910\n",
- "Non-trainable params: 0\n",
- "Total mult-adds (M): 21.05\n",
- "==========================================================================================\n",
- "Input size (MB): 0.10\n",
- "Forward/backward pass size (MB): 5.55\n",
- "Params size (MB): 3.07\n",
- "Estimated Total Size (MB): 8.73\n",
- "==========================================================================================\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "==========================================================================================\n",
- "Layer (type:depth-idx) Output Shape Param #\n",
- "==========================================================================================\n",
- "├─Sequential: 1-1 [-1, 64, 12, 474] --\n",
- "| └─Conv2d: 2-1 [-1, 64, 12, 474] 3,136\n",
- "| └─BatchNorm2d: 2-2 [-1, 64, 12, 474] 128\n",
- "| └─ReLU: 2-3 [-1, 64, 12, 474] --\n",
- "├─Sequential: 1-2 [-1, 68, 1, 30] --\n",
- "| └─ResidualLayer: 2-4 [-1, 64, 12, 474] --\n",
- "| | └─Sequential: 3-1 [-1, 64, 12, 474] 147,968\n",
- "| └─ResidualLayer: 2-5 [-1, 65, 6, 237] --\n",
- "| | └─Sequential: 3-2 [-1, 65, 6, 237] 156,325\n",
- "| └─ResidualLayer: 2-6 [-1, 66, 3, 119] --\n",
- "| | └─Sequential: 3-3 [-1, 66, 3, 119] 161,172\n",
- "| └─ResidualLayer: 2-7 [-1, 67, 2, 60] --\n",
- "| | └─Sequential: 3-4 [-1, 67, 2, 60] 166,093\n",
- "| └─ResidualLayer: 2-8 [-1, 68, 1, 30] --\n",
- "| | └─Sequential: 3-5 [-1, 68, 1, 30] 171,088\n",
- "==========================================================================================\n",
- "Total params: 805,910\n",
- "Trainable params: 805,910\n",
- "Non-trainable params: 0\n",
- "Total mult-adds (M): 21.05\n",
- "==========================================================================================\n",
- "Input size (MB): 0.10\n",
- "Forward/backward pass size (MB): 5.55\n",
- "Params size (MB): 3.07\n",
- "Estimated Total Size (MB): 8.73\n",
- "=========================================================================================="
- ]
- },
- "execution_count": 52,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"summary(backbone, (1, 28, 952), device=\"cpu\", depth=3)"
]
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -715,187 +434,25 @@
},
{
"cell_type": "code",
- "execution_count": 10,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Sequential(\n",
- " (0): SELU(inplace=True)\n",
- " (1): Sequential(\n",
- " (0): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
- " (2): SELU(inplace=True)\n",
- " (3): MaxPool2d(kernel_size=(2, 4), stride=2, padding=1, dilation=1, ceil_mode=False)\n",
- " )\n",
- " (2): Sequential(\n",
- " (0): Sequential(\n",
- " (0): WideBlock(\n",
- " (activation): SELU(inplace=True)\n",
- " (blocks): Sequential(\n",
- " (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
- " (1): SELU(inplace=True)\n",
- " (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.1, inplace=False)\n",
- " (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
- " (5): SELU(inplace=True)\n",
- " (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " )\n",
- " )\n",
- " )\n",
- " (1): Sequential(\n",
- " (0): WideBlock(\n",
- " (activation): SELU(inplace=True)\n",
- " (blocks): Sequential(\n",
- " (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
- " (1): SELU(inplace=True)\n",
- " (2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.1, inplace=False)\n",
- " (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
- " (5): SELU(inplace=True)\n",
- " (6): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
- " )\n",
- " (shortcut): Sequential(\n",
- " (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
- " )\n",
- " )\n",
- " )\n",
- " (2): Sequential(\n",
- " (0): WideBlock(\n",
- " (activation): SELU(inplace=True)\n",
- " (blocks): Sequential(\n",
- " (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
- " (1): SELU(inplace=True)\n",
- " (2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.1, inplace=False)\n",
- " (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
- " (5): SELU(inplace=True)\n",
- " (6): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
- " )\n",
- " (shortcut): Sequential(\n",
- " (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
- " )\n",
- " )\n",
- " )\n",
- " (3): Sequential(\n",
- " (0): WideBlock(\n",
- " (activation): SELU(inplace=True)\n",
- " (blocks): Sequential(\n",
- " (0): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
- " (1): SELU(inplace=True)\n",
- " (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.1, inplace=False)\n",
- " (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
- " (5): SELU(inplace=True)\n",
- " (6): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
- " )\n",
- " (shortcut): Sequential(\n",
- " (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
- " )\n",
- " )\n",
- " )\n",
- " )\n",
- ")"
- ]
- },
- "execution_count": 10,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"backbone"
]
},
{
"cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "==========================================================================================\n",
- "Layer (type:depth-idx) Output Shape Param #\n",
- "==========================================================================================\n",
- "├─Sequential: 1-1 [-1, 64, 7, 237] --\n",
- "| └─Conv2d: 2-1 [-1, 64, 12, 474] 3,136\n",
- "| └─BatchNorm2d: 2-2 [-1, 64, 12, 474] 128\n",
- "├─SELU: 1-2 [-1, 64, 12, 474] --\n",
- "├─Sequential: 1 [] --\n",
- "| └─SELU: 2-3 [-1, 64, 12, 474] --\n",
- "| └─MaxPool2d: 2-4 [-1, 64, 7, 237] --\n",
- "├─Sequential: 1-3 [-1, 256, 1, 30] --\n",
- "| └─Sequential: 2-5 [-1, 64, 7, 237] --\n",
- "| | └─WideBlock: 3-1 [-1, 64, 7, 237] 73,984\n",
- "| └─Sequential: 2-6 [-1, 128, 4, 119] --\n",
- "| | └─WideBlock: 3-2 [-1, 128, 4, 119] 229,760\n",
- "| └─Sequential: 2-7 [-1, 256, 2, 60] --\n",
- "| | └─WideBlock: 3-3 [-1, 256, 2, 60] 918,272\n",
- "| └─Sequential: 2-8 [-1, 256, 1, 30] --\n",
- "| | └─WideBlock: 3-4 [-1, 256, 1, 30] 1,246,208\n",
- "==========================================================================================\n",
- "Total params: 2,471,488\n",
- "Trainable params: 2,471,488\n",
- "Non-trainable params: 0\n",
- "Total mult-adds (M): 27.71\n",
- "==========================================================================================\n",
- "Input size (MB): 0.10\n",
- "Forward/backward pass size (MB): 5.55\n",
- "Params size (MB): 9.43\n",
- "Estimated Total Size (MB): 15.08\n",
- "==========================================================================================\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "==========================================================================================\n",
- "Layer (type:depth-idx) Output Shape Param #\n",
- "==========================================================================================\n",
- "├─Sequential: 1-1 [-1, 64, 7, 237] --\n",
- "| └─Conv2d: 2-1 [-1, 64, 12, 474] 3,136\n",
- "| └─BatchNorm2d: 2-2 [-1, 64, 12, 474] 128\n",
- "├─SELU: 1-2 [-1, 64, 12, 474] --\n",
- "├─Sequential: 1 [] --\n",
- "| └─SELU: 2-3 [-1, 64, 12, 474] --\n",
- "| └─MaxPool2d: 2-4 [-1, 64, 7, 237] --\n",
- "├─Sequential: 1-3 [-1, 256, 1, 30] --\n",
- "| └─Sequential: 2-5 [-1, 64, 7, 237] --\n",
- "| | └─WideBlock: 3-1 [-1, 64, 7, 237] 73,984\n",
- "| └─Sequential: 2-6 [-1, 128, 4, 119] --\n",
- "| | └─WideBlock: 3-2 [-1, 128, 4, 119] 229,760\n",
- "| └─Sequential: 2-7 [-1, 256, 2, 60] --\n",
- "| | └─WideBlock: 3-3 [-1, 256, 2, 60] 918,272\n",
- "| └─Sequential: 2-8 [-1, 256, 1, 30] --\n",
- "| | └─WideBlock: 3-4 [-1, 256, 1, 30] 1,246,208\n",
- "==========================================================================================\n",
- "Total params: 2,471,488\n",
- "Trainable params: 2,471,488\n",
- "Non-trainable params: 0\n",
- "Total mult-adds (M): 27.71\n",
- "==========================================================================================\n",
- "Input size (MB): 0.10\n",
- "Forward/backward pass size (MB): 5.55\n",
- "Params size (MB): 9.43\n",
- "Estimated Total Size (MB): 15.08\n",
- "=========================================================================================="
- ]
- },
- "execution_count": 11,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"summary(wr, (1, 28, 952), device=\"cpu\", depth=3)"
]
},
{
"cell_type": "code",
- "execution_count": 30,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -904,7 +461,7 @@
},
{
"cell_type": "code",
- "execution_count": 38,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -913,7 +470,7 @@
},
{
"cell_type": "code",
- "execution_count": 39,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -922,7 +479,7 @@
},
{
"cell_type": "code",
- "execution_count": 40,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -931,7 +488,7 @@
},
{
"cell_type": "code",
- "execution_count": 41,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -940,7 +497,7 @@
},
{
"cell_type": "code",
- "execution_count": 42,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -949,67 +506,34 @@
},
{
"cell_type": "code",
- "execution_count": 43,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([1, 119, 256, 1])"
- ]
- },
- "execution_count": 43,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"d.shape"
]
},
{
"cell_type": "code",
- "execution_count": 44,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([1, 119, 256])"
- ]
- },
- "execution_count": 44,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"d.squeeze(3).shape"
]
},
{
"cell_type": "code",
- "execution_count": 26,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([1, 256, 4, 119])"
- ]
- },
- "execution_count": 26,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"b.shape"
]
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -1018,47 +542,25 @@
},
{
"cell_type": "code",
- "execution_count": 70,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "96"
- ]
- },
- "execution_count": 70,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"32 + 64"
]
},
{
"cell_type": "code",
- "execution_count": 106,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "336"
- ]
- },
- "execution_count": 106,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"3 * 112"
]
},
{
"cell_type": "code",
- "execution_count": 39,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -1067,7 +569,7 @@
},
{
"cell_type": "code",
- "execution_count": 40,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -1076,62 +578,29 @@
},
{
"cell_type": "code",
- "execution_count": 42,
+ "execution_count": null,
"metadata": {
"scrolled": true
},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([4, 196, 128])"
- ]
- },
- "execution_count": 42,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"col_embed[:W].unsqueeze(0).repeat(H, 1, 1).shape"
]
},
{
"cell_type": "code",
- "execution_count": 44,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([4, 196, 128])"
- ]
- },
- "execution_count": 44,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"col_embed[:H].unsqueeze(1).repeat(1, W, 1).shape"
]
},
{
"cell_type": "code",
- "execution_count": 60,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([1, 4, 196, 256])"
- ]
- },
- "execution_count": 60,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
" torch.cat(\n",
" [\n",
@@ -1144,27 +613,16 @@
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "784"
- ]
- },
- "execution_count": 21,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"4 * 196"
]
},
{
"cell_type": "code",
- "execution_count": 39,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -1173,40 +631,18 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "8"
- ]
- },
- "execution_count": 14,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"torch.nonzero(target == 9, as_tuple=False)[0].item()"
]
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensor([ 1, 1, 12, 1, 1, 1, 1, 1, 9])"
- ]
- },
- "execution_count": 16,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"target[:9]"
]
@@ -1220,27 +656,16 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "inf"
- ]
- },
- "execution_count": 9,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"np.inf"
]
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -1249,22 +674,9 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "image/png": "\n",
- "text/plain": [
- "<Figure size 1080x360 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
+ "outputs": [],
"source": [
"plt.figure(figsize=(15, 5))\n",
"pe = PositionalEncoding(20, 0)\n",
@@ -1276,7 +688,7 @@
},
{
"cell_type": "code",
- "execution_count": 45,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -1285,7 +697,7 @@
},
{
"cell_type": "code",
- "execution_count": 57,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -1294,110 +706,25 @@
},
{
"cell_type": "code",
- "execution_count": 58,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "27.0"
- ]
- },
- "execution_count": 58,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"216 / 8"
]
},
{
"cell_type": "code",
- "execution_count": 59,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "==========================================================================================\n",
- "Layer (type:depth-idx) Output Shape Param #\n",
- "==========================================================================================\n",
- "├─Sequential: 1-1 [-1, 80] --\n",
- "| └─Conv2d: 2-1 [-1, 24, 28, 952] 216\n",
- "| └─BatchNorm2d: 2-2 [-1, 24, 28, 952] 48\n",
- "| └─ReLU: 2-3 [-1, 24, 28, 952] --\n",
- "| └─_DenseBlock: 2-4 [-1, 96, 28, 952] --\n",
- "| └─_Transition: 2-5 [-1, 48, 14, 476] --\n",
- "| | └─Sequential: 3-1 [-1, 48, 14, 476] 4,800\n",
- "| └─_DenseBlock: 2-6 [-1, 192, 14, 476] --\n",
- "| └─_Transition: 2-7 [-1, 96, 7, 238] --\n",
- "| | └─Sequential: 3-2 [-1, 96, 7, 238] 18,816\n",
- "| └─_DenseBlock: 2-8 [-1, 216, 7, 238] --\n",
- "| └─ReLU: 2-9 [-1, 216, 7, 238] --\n",
- "| └─AdaptiveAvgPool2d: 2-10 [-1, 216, 1, 1] --\n",
- "| └─Rearrange: 2-11 [-1, 216] --\n",
- "| └─Linear: 2-12 [-1, 80] 17,360\n",
- "==========================================================================================\n",
- "Total params: 41,240\n",
- "Trainable params: 41,240\n",
- "Non-trainable params: 0\n",
- "Total mult-adds (M): 252.43\n",
- "==========================================================================================\n",
- "Input size (MB): 0.10\n",
- "Forward/backward pass size (MB): 53.69\n",
- "Params size (MB): 0.16\n",
- "Estimated Total Size (MB): 53.95\n",
- "==========================================================================================\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "==========================================================================================\n",
- "Layer (type:depth-idx) Output Shape Param #\n",
- "==========================================================================================\n",
- "├─Sequential: 1-1 [-1, 80] --\n",
- "| └─Conv2d: 2-1 [-1, 24, 28, 952] 216\n",
- "| └─BatchNorm2d: 2-2 [-1, 24, 28, 952] 48\n",
- "| └─ReLU: 2-3 [-1, 24, 28, 952] --\n",
- "| └─_DenseBlock: 2-4 [-1, 96, 28, 952] --\n",
- "| └─_Transition: 2-5 [-1, 48, 14, 476] --\n",
- "| | └─Sequential: 3-1 [-1, 48, 14, 476] 4,800\n",
- "| └─_DenseBlock: 2-6 [-1, 192, 14, 476] --\n",
- "| └─_Transition: 2-7 [-1, 96, 7, 238] --\n",
- "| | └─Sequential: 3-2 [-1, 96, 7, 238] 18,816\n",
- "| └─_DenseBlock: 2-8 [-1, 216, 7, 238] --\n",
- "| └─ReLU: 2-9 [-1, 216, 7, 238] --\n",
- "| └─AdaptiveAvgPool2d: 2-10 [-1, 216, 1, 1] --\n",
- "| └─Rearrange: 2-11 [-1, 216] --\n",
- "| └─Linear: 2-12 [-1, 80] 17,360\n",
- "==========================================================================================\n",
- "Total params: 41,240\n",
- "Trainable params: 41,240\n",
- "Non-trainable params: 0\n",
- "Total mult-adds (M): 252.43\n",
- "==========================================================================================\n",
- "Input size (MB): 0.10\n",
- "Forward/backward pass size (MB): 53.69\n",
- "Params size (MB): 0.16\n",
- "Estimated Total Size (MB): 53.95\n",
- "=========================================================================================="
- ]
- },
- "execution_count": 59,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"summary(dnet, (1, 28, 952), device=\"cpu\", depth=3)"
]
},
{
"cell_type": "code",
- "execution_count": 84,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -1408,27 +735,16 @@
},
{
"cell_type": "code",
- "execution_count": 85,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Sequential()"
- ]
- },
- "execution_count": 85,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"backbone"
]
},
{
"cell_type": "code",
- "execution_count": 29,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -1437,7 +753,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -1455,74 +771,16 @@
},
{
"cell_type": "code",
- "execution_count": 10,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "==========================================================================================\n",
- "Layer (type:depth-idx) Output Shape Param #\n",
- "==========================================================================================\n",
- "├─Sequential: 1-1 [-1, 512, 2, 60] --\n",
- "| └─Conv2d: 2-1 [-1, 32, 28, 952] 288\n",
- "| └─Sequential: 2-2 [-1, 32, 28, 952] 18,560\n",
- "| └─Sequential: 2-3 [-1, 64, 14, 476] 57,536\n",
- "| └─Sequential: 2-4 [-1, 128, 7, 238] 229,760\n",
- "| └─Sequential: 2-5 [-1, 256, 4, 119] 918,272\n",
- "| └─Sequential: 2-6 [-1, 512, 2, 60] 3,671,552\n",
- "==========================================================================================\n",
- "Total params: 4,895,968\n",
- "Trainable params: 4,895,968\n",
- "Non-trainable params: 0\n",
- "Total mult-adds (M): 22.36\n",
- "==========================================================================================\n",
- "Input size (MB): 0.10\n",
- "Forward/backward pass size (MB): 6.51\n",
- "Params size (MB): 18.68\n",
- "Estimated Total Size (MB): 25.29\n",
- "==========================================================================================\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "==========================================================================================\n",
- "Layer (type:depth-idx) Output Shape Param #\n",
- "==========================================================================================\n",
- "├─Sequential: 1-1 [-1, 512, 2, 60] --\n",
- "| └─Conv2d: 2-1 [-1, 32, 28, 952] 288\n",
- "| └─Sequential: 2-2 [-1, 32, 28, 952] 18,560\n",
- "| └─Sequential: 2-3 [-1, 64, 14, 476] 57,536\n",
- "| └─Sequential: 2-4 [-1, 128, 7, 238] 229,760\n",
- "| └─Sequential: 2-5 [-1, 256, 4, 119] 918,272\n",
- "| └─Sequential: 2-6 [-1, 512, 2, 60] 3,671,552\n",
- "==========================================================================================\n",
- "Total params: 4,895,968\n",
- "Trainable params: 4,895,968\n",
- "Non-trainable params: 0\n",
- "Total mult-adds (M): 22.36\n",
- "==========================================================================================\n",
- "Input size (MB): 0.10\n",
- "Forward/backward pass size (MB): 6.51\n",
- "Params size (MB): 18.68\n",
- "Estimated Total Size (MB): 25.29\n",
- "=========================================================================================="
- ]
- },
- "execution_count": 10,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"summary(w, (1, 28, 952), device=\"cpu\", depth=2)"
]
},
{
"cell_type": "code",
- "execution_count": 46,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -1531,7 +789,7 @@
},
{
"cell_type": "code",
- "execution_count": 47,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -1541,7 +799,7 @@
},
{
"cell_type": "code",
- "execution_count": 50,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -1551,71 +809,34 @@
},
{
"cell_type": "code",
- "execution_count": 52,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([100, 1, 256])"
- ]
- },
- "execution_count": 52,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"h.flatten(2).permute(2, 0, 1).shape"
]
},
{
"cell_type": "code",
- "execution_count": 91,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([100, 1, 256])"
- ]
- },
- "execution_count": 91,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"h.flatten(2).permute(2, 0, 1).shape"
]
},
{
"cell_type": "code",
- "execution_count": 48,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensor([[0., -inf, -inf, -inf, -inf],\n",
- " [0., 0., -inf, -inf, -inf],\n",
- " [0., 0., 0., -inf, -inf],\n",
- " [0., 0., 0., 0., -inf],\n",
- " [0., 0., 0., 0., 0.]])"
- ]
- },
- "execution_count": 48,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"mask\n"
]
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -1625,7 +846,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -1634,71 +855,34 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensor([ True, True, True, True, True, True, False, False, False, False,\n",
- " True, True, True, True, True, True, False, False, False, False,\n",
- " True, True, True, True, True, True, False, False, False, False])"
- ]
- },
- "execution_count": 10,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"mask"
]
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensor([ 1, 21, 2, 45, 31, 81, 0, 0, 0, 0, 2, 1, 1, 1, 1, 81, 0, 0,\n",
- " 0, 0, 1, 1, 1, 1, 1, 81, 0, 0, 0, 0])"
- ]
- },
- "execution_count": 11,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"pred * mask"
]
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensor([ 1, 1, 1, 1, 1, 81, 0, 0, 0, 0, 1, 1, 1, 1, 1, 81, 0, 0,\n",
- " 0, 0, 1, 1, 1, 1, 1, 81, 0, 0, 0, 0])"
- ]
- },
- "execution_count": 12,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"target * mask"
]
},
{
"cell_type": "code",
- "execution_count": 32,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -1707,7 +891,7 @@
},
{
"cell_type": "code",
- "execution_count": 61,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -1716,7 +900,7 @@
},
{
"cell_type": "code",
- "execution_count": 76,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -1725,27 +909,16 @@
},
{
"cell_type": "code",
- "execution_count": 66,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "30"
- ]
- },
- "execution_count": 66,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"target.shape[0]"
]
},
{
"cell_type": "code",
- "execution_count": 84,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -1754,27 +927,16 @@
},
{
"cell_type": "code",
- "execution_count": 85,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensor([10, 20, 30])"
- ]
- },
- "execution_count": 85,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"t2"
]
},
{
"cell_type": "code",
- "execution_count": 89,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -1784,160 +946,72 @@
},
{
"cell_type": "code",
- "execution_count": 90,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensor([ 1, 1, 1, 1, 1, 81, 79, 79, 79, 79, 2, 1, 1, 1, 1, 81, 79, 79,\n",
- " 79, 79, 1, 1, 1, 1, 1, 81, 79, 79, 79, 79])"
- ]
- },
- "execution_count": 90,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"pred"
]
},
{
"cell_type": "code",
- "execution_count": 88,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "ename": "SyntaxError",
- "evalue": "invalid syntax (<ipython-input-88-b8a4aef86401>, line 1)",
- "output_type": "error",
- "traceback": [
- "\u001b[0;36m File \u001b[0;32m\"<ipython-input-88-b8a4aef86401>\"\u001b[0;36m, line \u001b[0;32m1\u001b[0m\n\u001b[0;31m [pred[start+1:stop] = 79 for start, stop in zip(t1, t2)]\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"[pred[start+1:stop] = 79 for start, stop in zip(t1, t2)]"
]
},
{
"cell_type": "code",
- "execution_count": 69,
+ "execution_count": null,
"metadata": {
"scrolled": true
},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensor([[ 6],\n",
- " [ 7],\n",
- " [ 8],\n",
- " [ 9],\n",
- " [16],\n",
- " [17],\n",
- " [18],\n",
- " [19],\n",
- " [26],\n",
- " [27],\n",
- " [28],\n",
- " [29]])"
- ]
- },
- "execution_count": 69,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"pad_indcies"
]
},
{
"cell_type": "code",
- "execution_count": 71,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "ename": "TypeError",
- "evalue": "only integer tensors of a single element can be converted to an index",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m<ipython-input-71-39b5cc3b1445>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mpred\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpad_indcies\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mpad_indcies\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m79\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
- "\u001b[0;31mTypeError\u001b[0m: only integer tensors of a single element can be converted to an index"
- ]
- }
- ],
+ "outputs": [],
"source": [
"pred[pad_indcies:pad_indcies] = 79"
]
},
{
"cell_type": "code",
- "execution_count": 50,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([20])"
- ]
- },
- "execution_count": 50,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"pred.shape"
]
},
{
"cell_type": "code",
- "execution_count": 51,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([20])"
- ]
- },
- "execution_count": 51,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"target.shape"
]
},
{
"cell_type": "code",
- "execution_count": 91,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "0.0"
- ]
- },
- "execution_count": 91,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"accuracy(pred, target)"
]
},
{
"cell_type": "code",
- "execution_count": 92,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -1946,20 +1020,9 @@
},
{
"cell_type": "code",
- "execution_count": 93,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensor(0.9667)"
- ]
- },
- "execution_count": 93,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"acc"
]
diff --git a/src/notebooks/07-try-gtn.ipynb b/src/notebooks/07-try-gtn.ipynb
index d366dec..4ef444b 100644
--- a/src/notebooks/07-try-gtn.ipynb
+++ b/src/notebooks/07-try-gtn.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -125,6 +125,53 @@
},
{
"cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[1.0, 0.0, 0.5, 0.5]\n"
+ ]
+ }
+ ],
+ "source": [
+ "import gtn\n",
+ "\n",
+ "# Make some graphs:\n",
+ "g1 = gtn.Graph()\n",
+ "g1.add_node(True) # Add a start node\n",
+ "g1.add_node() # Add an internal node\n",
+ "g1.add_node(False, True) # Add an accepting node\n",
+ "\n",
+ "# Add arcs with (src node, dst node, label):\n",
+ "g1.add_arc(0, 1, 1)\n",
+ "g1.add_arc(0, 1, 2)\n",
+ "g1.add_arc(1, 2, 1)\n",
+ "g1.add_arc(1, 2, 0)\n",
+ "\n",
+ "g2 = gtn.Graph()\n",
+ "g2.add_node(True, True)\n",
+ "g2.add_arc(0, 0, 1)\n",
+ "g2.add_arc(0, 0, 0)\n",
+ "\n",
+ "# Compute a function of the graphs:\n",
+ "intersection = gtn.intersect(g1, g2)\n",
+ "score = gtn.forward_score(intersection)\n",
+ "\n",
+ "# Visualize the intersected graph:\n",
+ "gtn.draw(intersection, \"intersection.pdf\")\n",
+ "\n",
+ "# Backprop:\n",
+ "gtn.backward(score)\n",
+ "\n",
+ "# Print gradients of arc weights \n",
+ "print(g1.grad().weights_to_list()) # [1.0, 0.0, 0.5, 0.5]"
+ ]
+ },
+ {
+ "cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
diff --git a/src/notebooks/Untitled.ipynb b/src/notebooks/Untitled.ipynb
new file mode 100644
index 0000000..841a37d
--- /dev/null
+++ b/src/notebooks/Untitled.ipynb
@@ -0,0 +1,385 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "\n",
+ "%matplotlib inline\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "from PIL import Image\n",
+ "import torch\n",
+ "from torch import nn\n",
+ "\n",
+ "from importlib.util import find_spec\n",
+ "if find_spec(\"text_recognizer\") is None:\n",
+ " import sys\n",
+ " sys.path.append('..')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.datasets import IamLinesDataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "transform = [{\"type\": \"ToPILImage\", \"args\": None}, \n",
+ " #{\"type\": \"RandomResizeCrop\", \"args\": None}, \n",
+ " {\"type\": \"RandomRotation\", \"args\": {\"degrees\": 0.8, \"fill\": 0}}, \n",
+ " {\"type\": \"ColorJitter\", \"args\": {\"brightness\": 0.5, \"contrast\": 0.5, \"saturation\": 0.5, \"hue\": 0.5}}, \n",
+ " {\"type\": \"ToTensor\", \"args\": None}, \n",
+ " {\"type\": \"Normalize\", \"args\": {\"mean\": [0.912], \"std\": 0.168}},\n",
+ " #{\"type\": \"RandomAffine\", \"args\": {\"degrees\": [-0.25, 0.25], \"scale\": [0.98, 1.0]}}\n",
+ " ]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "target_transforms = [\n",
+ " {\"type\": \"ToLower\", \"args\": None},\n",
+ " {\"type\": \"ToCharcters\", \"args\": {\"pad_token\": \"_\", \"eos_token\": \"</s>\"}},\n",
+ " {\"type\": \"ToWordPieces\", \"args\": {\n",
+ " \"num_features\": 64, \n",
+ " \"tokens\": \"iamdb_1kwp_tokens_1000.txt\", \n",
+ " \"lexicon\": \"iamdb_1kwp_lex_1000.txt\",\n",
+ " \"use_words\": False,\n",
+ " \"prepend_wordsep\": False,\n",
+ " }\n",
+ " }\n",
+ " \n",
+ "]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.datasets.transforms import ToText"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2021-02-24 21:43:47.687 | DEBUG | text_recognizer.datasets.transforms:__init__:201 - Using data dir: /home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/data/raw/iam/iamdb\n"
+ ]
+ }
+ ],
+ "source": [
+ "to_text = ToText(\n",
+ " num_features= 64, \n",
+ " tokens=\"iamdb_1kwp_tokens_1000.txt\", \n",
+ " lexicon=\"iamdb_1kwp_lex_1000.txt\",\n",
+ " use_words=False,\n",
+ " prepend_wordsep= False,)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2021-02-24 21:42:02.700 | DEBUG | text_recognizer.datasets.transforms:__init__:201 - Using data dir: /home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/data/raw/iam/iamdb\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "IAM Lines Dataset\n",
+ "Number classes: 54\n",
+ "Mapping: {0: '0', 1: '1', 2: '2', 3: '3', 4: '4', 5: '5', 6: '6', 7: '7', 8: '8', 9: '9', 10: 'a', 11: 'b', 12: 'c', 13: 'd', 14: 'e', 15: 'f', 16: 'g', 17: 'h', 18: 'i', 19: 'j', 20: 'k', 21: 'l', 22: 'm', 23: 'n', 24: 'o', 25: 'p', 26: 'q', 27: 'r', 28: 's', 29: 't', 30: 'u', 31: 'v', 32: 'w', 33: 'x', 34: 'y', 35: 'z', 36: ' ', 37: '!', 38: '\"', 39: '#', 40: '&', 41: \"'\", 42: '(', 43: ')', 44: '*', 45: '+', 46: ',', 47: '-', 48: '.', 49: '/', 50: ':', 51: ';', 52: '?', 53: '_'}\n",
+ "Data: (1861, 28, 952)\n",
+ "Targets: (1861, 97)\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "dataset = IamLinesDataset(train=False, pad_token=\"_\", transform=transform, target_transform=target_transforms, lower=True)\n",
+ "dataset.load_or_generate_data()\n",
+ "print(dataset)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "but▁since▁starting▁salaries▁would▁depend▁on▁grade▁a\n",
+ "or▁b▁in▁the▁finals▁next▁may,▁and▁since▁mating\n",
+ "prospects▁would▁depend▁upon▁salaries,▁scholarship▁for\n",
+ "these▁fine▁young▁people▁was▁closely▁geared▁to\n",
+ "economic▁and▁biological▁ends▁which,▁essentially,\n",
+ "were▁really▁means.▁so,▁seeing▁them▁revolve▁in\n",
+ "circles,▁harry▁had▁the▁feeling▁that▁moke▁(or▁what\n",
+ "moke▁consciously▁or▁unconsciously▁symbolised,▁any-\n",
+ "way▁in▁harry's▁mind)▁had▁these▁splendid▁young\n",
+ "people▁by▁the▁short▁hairs,▁and▁was▁diverting▁them▁...\n"
+ ]
+ },
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ "<Figure size 1440x1440 with 1 Axes>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ "<Figure size 1440x1440 with 1 Axes>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ "<Figure size 1440x1440 with 1 Axes>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ "<Figure size 1440x1440 with 1 Axes>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ "<Figure size 1440x1440 with 1 Axes>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ "<Figure size 1440x1440 with 1 Axes>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ "<Figure size 1440x1440 with 1 Axes>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ "<Figure size 1440x1440 with 1 Axes>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ "<Figure size 1440x1440 with 1 Axes>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ "<Figure size 1440x1440 with 1 Axes>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "for i in range(190, 200):\n",
+ " plt.figure(figsize=(20, 20))\n",
+ " plt.xticks([])\n",
+ " plt.yticks([])\n",
+ " data, target = dataset[i]\n",
+ "# print(target)\n",
+ " print(to_text(target))\n",
+ "# target = [x - 26 if x > 35 else x for x in target]\n",
+ "# sentence = convert_y_label_to_string(target, dataset) \n",
+ "# print(target)\n",
+ "# plt.title(sentence)\n",
+ " plt.imshow(data.squeeze(0).numpy(), cmap='gray')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "target.tolist()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset.target_transform"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.networks.transducer import load_transducer_loss, Transducer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t, i =load_transducer_loss(64, \n",
+ " 0,\n",
+ " \"iamdb_1kwp_tokens_1000.txt\", \n",
+ " \"iamdb_1kwp_lex_1000.txt\",\n",
+ " \"1kwp_prune_0_0_optblank.bin\",\n",
+ " \"optional\",\n",
+ " False,\n",
+ " False,\n",
+ " False,\n",
+ " None,\n",
+ " \"mean\"\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t(target, target)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "target.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.2"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/src/notebooks/intersection.pdf b/src/notebooks/intersection.pdf
new file mode 100644
index 0000000..c425a9f
--- /dev/null
+++ b/src/notebooks/intersection.pdf
Binary files differ
diff --git a/src/tasks/build_transitions.py b/src/tasks/build_transitions.py
index b12c9bc..91f8c1a 100644
--- a/src/tasks/build_transitions.py
+++ b/src/tasks/build_transitions.py
@@ -9,7 +9,7 @@ Most code stolen from here:
import collections
import itertools
from pathlib import Path
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
import click
import gtn
@@ -18,7 +18,7 @@ from loguru import logger
START_IDX = -1
END_IDX = -2
-WORDSEP = "_"
+WORDSEP = "▁"
def build_graph(ngrams: List, disable_backoff: bool = False) -> gtn.Graph:
@@ -27,7 +27,7 @@ def build_graph(ngrams: List, disable_backoff: bool = False) -> gtn.Graph:
ngram = len(ngrams)
state_to_node = {}
- def get_node(state: Optional[List]) -> gtn.node:
+ def get_node(state: Optional[List]) -> Any:
node = state_to_node.get(state, None)
if node is not None:
diff --git a/src/tasks/make_wordpieces.py b/src/tasks/make_wordpieces.py
index f605920..2ac0e2c 100644
--- a/src/tasks/make_wordpieces.py
+++ b/src/tasks/make_wordpieces.py
@@ -30,7 +30,7 @@ def iamdb_pieces(
user_symbols=["/"], # added so token is in the output set
)
- vocab = sorted(set(w for t in text for w in t.split("_") if w))
+ vocab = sorted(set(w for t in text for w in t.split("▁") if w))
if "move" not in vocab:
raise RuntimeError("`MOVE` not in vocab")
diff --git a/src/text_recognizer/datasets/iam_preprocessor.py b/src/text_recognizer/datasets/iam_preprocessor.py
index 5a5136c..a93eb00 100644
--- a/src/text_recognizer/datasets/iam_preprocessor.py
+++ b/src/text_recognizer/datasets/iam_preprocessor.py
@@ -59,7 +59,7 @@ class Preprocessor:
use_words: bool = False,
prepend_wordsep: bool = False,
) -> None:
- self.wordsep = "_"
+ self.wordsep = "▁"
self._use_word = use_words
self._prepend_wordsep = prepend_wordsep
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py
index 60987e0..b6a48f5 100644
--- a/src/text_recognizer/datasets/transforms.py
+++ b/src/text_recognizer/datasets/transforms.py
@@ -1,6 +1,10 @@
"""Transforms for PyTorch datasets."""
+from abc import abstractmethod
+from pathlib import Path
import random
+from typing import Any, Optional, Union
+from loguru import logger
import numpy as np
from PIL import Image
import torch
@@ -18,6 +22,7 @@ from torchvision.transforms import (
ToTensor,
)
+from text_recognizer.datasets.iam_preprocessor import Preprocessor
from text_recognizer.datasets.util import EmnistMapper
@@ -145,3 +150,117 @@ class ToLower:
"""Corrects index value in target tensor."""
device = target.device
return torch.stack([x - 26 if x > 35 else x for x in target]).to(device)
+
+
+class ToCharcters:
+    """Transform that turns a sequence of integer indices into a string."""
+
+    def __init__(
+        self, pad_token: str, eos_token: str, init_token: str = None, lower: bool = True
+    ) -> None:
+        """Stores the special tokens and builds the underlying `EmnistMapper`.
+
+        Args:
+            pad_token (str): Padding token handed to the mapper.
+            eos_token (str): End-of-sequence token handed to the mapper.
+            init_token (str): Optional start token; it is only passed to the
+                mapper when it is not None.
+            lower (bool): Forwarded unchanged to the mapper.
+        """
+        self.init_token = init_token
+        self.pad_token = pad_token
+        self.eos_token = eos_token
+
+        mapper_kwargs = {
+            "pad_token": self.pad_token,
+            "eos_token": self.eos_token,
+            "lower": lower,
+        }
+        # Leave `init_token` out entirely when it was not given, so the mapper
+        # falls back to its own default for that argument.
+        if self.init_token is not None:
+            mapper_kwargs["init_token"] = self.init_token
+        self.emnist_mapper = EmnistMapper(**mapper_kwargs)
+
+    def __call__(self, y: Tensor) -> str:
+        """Maps every index in `y` through the mapper and joins the result.
+
+        Leading/trailing "_" characters are stripped and spaces are replaced
+        with the word separator "▁".
+        """
+        characters = [self.emnist_mapper(int(index)) for index in y]
+        text = "".join(characters)
+        return text.strip("_").replace(" ", "▁")
+
+
+class WordPieces:
+    """Base transform for word pieces; owns the shared `Preprocessor`.
+
+    NOTE: `__call__` is decorated with `abstractmethod`, but this class does
+    not derive from `abc.ABC`, so the decorator is not enforced at
+    instantiation time. Subclasses are expected to override `__call__`.
+    """
+
+    def __init__(
+        self,
+        num_features: int,
+        data_dir: Optional[Union[str, Path]] = None,
+        tokens: Optional[Union[str, Path]] = None,
+        lexicon: Optional[Union[str, Path]] = None,
+        use_words: bool = False,
+        prepend_wordsep: bool = False,
+    ) -> None:
+        """Resolves the data paths and builds the preprocessor.
+
+        Args:
+            num_features (int): Forwarded to `Preprocessor`.
+            data_dir (Optional[Union[str, Path]]): Directory of the raw iamdb
+                data. Defaults to `data/raw/iam/iamdb` under the directory
+                three levels above this file's directory.
+            tokens (Optional[Union[str, Path]]): Tokens file, relative to
+                `data/processed/iam_lines`. Required despite the None default.
+            lexicon (Optional[Union[str, Path]]): Lexicon file, relative to
+                `data/processed/iam_lines`. Required despite the None default.
+            use_words (bool): Forwarded to `Preprocessor`.
+            prepend_wordsep (bool): Forwarded to `Preprocessor`.
+
+        Raises:
+            RuntimeError: If the default data directory does not exist.
+            TypeError: If `tokens` or `lexicon` is None.
+        """
+        if data_dir is None:
+            data_dir = (
+                Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb"
+            )
+            logger.debug(f"Using data dir: {data_dir}")
+            if not data_dir.exists():
+                raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
+        else:
+            data_dir = Path(data_dir)
+
+        # Joining a Path with None fails with an opaque "unsupported operand"
+        # TypeError; raise the same exception type with a usable message.
+        if tokens is None or lexicon is None:
+            raise TypeError(
+                "Both `tokens` and `lexicon` must be given as a str or Path, "
+                f"got tokens={tokens!r} and lexicon={lexicon!r}."
+            )
+
+        processed_path = (
+            Path(__file__).resolve().parents[3] / "data" / "processed" / "iam_lines"
+        )
+        tokens_path = processed_path / tokens
+        lexicon_path = processed_path / lexicon
+
+        self.preprocessor = Preprocessor(
+            data_dir,
+            num_features,
+            tokens_path,
+            lexicon_path,
+            use_words,
+            prepend_wordsep,
+        )
+
+    @abstractmethod
+    def __call__(self, *args, **kwargs) -> Any:
+        """Transforms input."""
+        ...
+
+
+class ToWordPieces(WordPieces):
+    """Transform that encodes a line of text as word piece indices."""
+
+    def __init__(
+        self,
+        num_features: int,
+        data_dir: Optional[Union[str, Path]] = None,
+        tokens: Optional[Union[str, Path]] = None,
+        lexicon: Optional[Union[str, Path]] = None,
+        use_words: bool = False,
+        prepend_wordsep: bool = False,
+    ) -> None:
+        """Passes every argument through unchanged to `WordPieces.__init__`."""
+        super().__init__(
+            num_features=num_features,
+            data_dir=data_dir,
+            tokens=tokens,
+            lexicon=lexicon,
+            use_words=use_words,
+            prepend_wordsep=prepend_wordsep,
+        )
+
+    def __call__(self, line: str) -> Tensor:
+        """Returns the result of the preprocessor's `to_index` for `line`."""
+        indices = self.preprocessor.to_index(line)
+        return indices
+
+
+class ToText(WordPieces):
+    """Transform that decodes a tensor of word piece indices back to text."""
+
+    def __init__(
+        self,
+        num_features: int,
+        data_dir: Optional[Union[str, Path]] = None,
+        tokens: Optional[Union[str, Path]] = None,
+        lexicon: Optional[Union[str, Path]] = None,
+        use_words: bool = False,
+        prepend_wordsep: bool = False,
+    ) -> None:
+        """Passes every argument through unchanged to `WordPieces.__init__`."""
+        super().__init__(
+            num_features=num_features,
+            data_dir=data_dir,
+            tokens=tokens,
+            lexicon=lexicon,
+            use_words=use_words,
+            prepend_wordsep=prepend_wordsep,
+        )
+
+    def __call__(self, x: Tensor) -> str:
+        """Converts `x` to a plain list and hands it to the preprocessor's `to_text`."""
+        indices = x.tolist()
+        return self.preprocessor.to_text(indices)
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
index bac5d28..1521355 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/src/text_recognizer/networks/__init__.py
@@ -8,7 +8,7 @@ from .lenet import LeNet
from .metrics import accuracy, cer, wer
from .mlp import MLP
from .residual_network import ResidualNetwork, ResidualNetworkEncoder
-from .transducer import TDS2d
+from .transducer import load_transducer_loss, TDS2d
from .transformer import Transformer
from .unet import UNet
from .util import sliding_window
@@ -28,6 +28,7 @@ __all__ = [
"greedy_decoder",
"MLP",
"LeNet",
+ "load_transducer_loss",
"ResidualNetwork",
"ResidualNetworkEncoder",
"sliding_window",
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py
index 7133c26..a2d7926 100644
--- a/src/text_recognizer/networks/cnn_transformer.py
+++ b/src/text_recognizer/networks/cnn_transformer.py
@@ -112,11 +112,11 @@ class CNNTransformer(nn.Module):
if self.max_pool is not None:
src = self.max_pool(src)
- if self.adaptive_pool is not None:
+ if self.adaptive_pool is not None and len(src.shape) == 4:
src = rearrange(src, "b c h w -> b w c h")
src = self.adaptive_pool(src)
src = src.squeeze(3)
- else:
+ elif len(src.shape) == 4:
src = rearrange(src, "b c h w -> b (h w) c")
b, t, _ = src.shape
diff --git a/src/text_recognizer/networks/transducer/__init__.py b/src/text_recognizer/networks/transducer/__init__.py
index fdd6662..8c19a01 100644
--- a/src/text_recognizer/networks/transducer/__init__.py
+++ b/src/text_recognizer/networks/transducer/__init__.py
@@ -1,2 +1,3 @@
"""Transducer modules."""
from .tds_conv import TDS2d
+from .transducer import load_transducer_loss, Transducer
diff --git a/src/text_recognizer/networks/transducer/tds_conv.py b/src/text_recognizer/networks/transducer/tds_conv.py
index 018caf2..5fb8ba9 100644
--- a/src/text_recognizer/networks/transducer/tds_conv.py
+++ b/src/text_recognizer/networks/transducer/tds_conv.py
@@ -136,8 +136,10 @@ class TDS2d(nn.Module):
self.tds = None
self.fc = None
- def _build_network(self) -> None:
+ self._build_network()
+ def _build_network(self) -> None:
+ in_channels = self.in_channels
modules = []
stride_h = np.prod([grp["stride"][0] for grp in self.tds_groups])
if self.input_dim % stride_h:
@@ -151,7 +153,7 @@ class TDS2d(nn.Module):
modules.extend(
[
nn.Conv2d(
- in_channels=self.in_channels,
+ in_channels=in_channels,
out_channels=out_channels,
kernel_size=self.kernel_size,
padding=(self.kernel_size[0] // 2, self.kernel_size[1] // 2),
@@ -173,12 +175,10 @@ class TDS2d(nn.Module):
)
)
- self.in_channels = out_channels
+ in_channels = out_channels
self.tds = nn.Sequential(*modules)
- self.fc = nn.Linear(
- self.in_channels * self.input_dim // stride_h, self.output_dim
- )
+ self.fc = nn.Linear(in_channels * self.input_dim // stride_h, self.output_dim)
def forward(self, x: Tensor) -> Tensor:
"""Forward pass.
@@ -193,6 +193,9 @@ class TDS2d(nn.Module):
Tensor: Output tensor.
"""
+ if len(x.shape) == 4:
+ x = x.squeeze(1) # Squeeze the channel dim away.
+
B, H, W = x.shape
x = rearrange(
x, "b (h1 h2) w -> b h1 h2 w", h1=self.in_channels, h2=H // self.in_channels
diff --git a/src/text_recognizer/networks/transducer/test.py b/src/text_recognizer/networks/transducer/test.py
new file mode 100644
index 0000000..cadcecc
--- /dev/null
+++ b/src/text_recognizer/networks/transducer/test.py
@@ -0,0 +1,60 @@
+import torch
+from torch import nn
+
+from text_recognizer.networks.transducer import load_transducer_loss, Transducer
+import unittest
+
+
+class TestTransducer(unittest.TestCase):
+    """Unit tests for `Transducer`."""
+
+    def test_viterbi(self):
+        """Checks `Transducer.viterbi` on two hand-written emission matrices."""
+        # Each matrix has one row per time step and one column per label.
+        # fmt: off
+        emissions1 = torch.tensor(
+            [
+                [0, 4, 0, 1],
+                [0, 2, 1, 1],
+                [0, 0, 0, 2],
+                [0, 0, 0, 2],
+                [8, 0, 0, 2],
+            ],
+            dtype=torch.float,
+        )
+        emissions2 = torch.tensor(
+            [
+                [0, 2, 1, 7],
+                [0, 2, 9, 1],
+                [0, 0, 0, 2],
+                [0, 0, 5, 2],
+                [1, 0, 0, 2],
+            ],
+            dtype=torch.float,
+        )
+        # fmt: on
+
+        # Case 1: four tokens and no blank label.
+        expected = [[1, 3, 0], [3, 2, 3, 2, 3]]
+        criterion = Transducer(
+            tokens=["a", "b", "c", "d"],
+            graphemes_to_idx={"a": 0, "b": 1, "c": 2, "d": 3},
+            blank="none",
+        )
+        batch = torch.stack([emissions1, emissions2], dim=0)
+        decoded = criterion.viterbi(batch)
+        self.assertEqual([path.tolist() for path in decoded], expected)
+
+        # Case 2: three tokens, an optional blank, and repeats disallowed.
+        expected = [[1, 0], [2, 2]]
+        criterion = Transducer(
+            tokens=["a", "b", "c"],
+            graphemes_to_idx={"a": 0, "b": 1, "c": 2},
+            blank="optional",
+            allow_repeats=False,
+        )
+        batch = torch.stack([emissions1, emissions2], dim=0)
+        decoded = criterion.viterbi(batch)
+        self.assertEqual([path.tolist() for path in decoded], expected)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/src/text_recognizer/networks/transducer/transducer.py b/src/text_recognizer/networks/transducer/transducer.py
new file mode 100644
index 0000000..d7e3d08
--- /dev/null
+++ b/src/text_recognizer/networks/transducer/transducer.py
@@ -0,0 +1,410 @@
+"""Transducer and the transducer loss function.
+
+Stolen from:
+ https://github.com/facebookresearch/gtn_applications/blob/master/transducer.py
+
+"""
+from pathlib import Path
+import itertools
+from typing import Dict, List, Optional, Union, Tuple
+
+from loguru import logger
+import gtn
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.datasets.iam_preprocessor import Preprocessor
+
+
+def make_scalar_graph(weight) -> gtn.Graph:
+    """Builds a two-node graph whose only arc carries `weight`.
+
+    The arc runs from the start node to the accepting node and uses label 0
+    for both of its labels.
+    """
+    graph = gtn.Graph()
+    start = graph.add_node(True)
+    accept = graph.add_node(False, True)
+    graph.add_arc(start, accept, 0, 0, weight)
+    return graph
+
+
+def make_chain_graph(sequence) -> gtn.Graph:
+    """Builds a linear graph with one arc per element of `sequence`.
+
+    Arc `pos` connects node `pos` to node `pos + 1` and is labelled with the
+    element at that position; only the node after the last arc is accepting.
+    The graph is created with `gtn.Graph(False)`.
+    """
+    chain = gtn.Graph(False)
+    chain.add_node(True)
+    for pos, label in enumerate(sequence):
+        is_last = pos == (len(sequence) - 1)
+        chain.add_node(False, is_last)
+        chain.add_arc(pos, pos + 1, label)
+    return chain
+
+
+def make_transitions_graph(
+    ngram: int, num_tokens: int, calc_grad: bool = False
+) -> gtn.Graph:
+    """Builds the n-gram transition graph over `num_tokens` labels.
+
+    Nodes stand for label histories (tuples of label indices); the start node
+    is the empty history. For `ngram > 1` a dedicated accepting end node is
+    appended and reached from every other node through an epsilon arc.
+
+    Args:
+        ngram (int): Order of the transition model.
+        num_tokens (int): Number of distinct labels.
+        calc_grad (bool): Passed to the `gtn.Graph` constructor.
+
+    Returns:
+        gtn.Graph: The transition graph.
+    """
+    graph = gtn.Graph(calc_grad)
+    is_unigram = ngram == 1
+    graph.add_node(True, is_unigram)
+
+    node_of = {(): 0}
+    labels = range(num_tokens)
+
+    # Histories shorter than `ngram`, i.e. the ones that still include <s>:
+    for order in range(1, ngram):
+        for history in itertools.product(labels, repeat=order):
+            src = node_of[history[:-1]]
+            dst = graph.add_node(False, is_unigram)
+            node_of[history] = dst
+            graph.add_arc(src, dst, history[-1])
+
+    # Full-order arcs, p(history[-1] | history[:-1]):
+    for history in itertools.product(labels, repeat=ngram):
+        graph.add_arc(node_of[history[:-1]], node_of[history[1:]], history[-1])
+
+    if ngram > 1:
+        # Transitions which include </s>:
+        end_node = graph.add_node(False, True)
+        for src in range(end_node):
+            graph.add_arc(src, end_node, gtn.epsilon)
+
+    return graph
+
+
+def make_lexicon_graph(word_pieces: List, graphemes_to_idx: Dict) -> gtn.Graph:
+    """Constructs a graph which transduces letters to word pieces.
+
+    Every word piece becomes a loop that leaves node 0 and returns to it. All
+    arcs but the last emit epsilon; the closing arc emits the index of the
+    word piece in `word_pieces`. The graph is arc-sorted before it is returned.
+    """
+    lexicon = gtn.Graph(False)
+    lexicon.add_node(True, True)
+    for piece_idx, piece in enumerate(word_pieces):
+        current = 0
+        for grapheme in piece[:-1]:
+            following = lexicon.add_node()
+            lexicon.add_arc(
+                current, following, graphemes_to_idx[grapheme], gtn.epsilon
+            )
+            current = following
+        # The final grapheme closes the loop and outputs the word piece index.
+        lexicon.add_arc(current, 0, graphemes_to_idx[piece[-1]], piece_idx)
+    lexicon.arc_sort()
+    return lexicon
+
+
+def make_token_graph(
+    token_list: List, blank: str = "none", allow_repeats: bool = True
+) -> gtn.Graph:
+    """Constructs a graph with all the individual token transition models.
+
+    Node 0 is the start node, node `i + 1` belongs to token `i`, and when a
+    blank is used it gets the extra node `len(token_list) + 1`. The blank
+    label is assumed to be the last index, `len(token_list)`.
+
+    Args:
+        token_list (List): The tokens; only its length is used.
+        blank (str): One of "none", "optional" or "forced".
+        allow_repeats (bool): If False, direct token-to-other-token arcs are
+            added in place of the epsilon arc back to the start node.
+
+    Raises:
+        ValueError: If repeats are disallowed and `blank` is not "optional".
+
+    Returns:
+        gtn.Graph: The token graph.
+    """
+    if blank != "optional" and not allow_repeats:
+        raise ValueError("Must use blank='optional' if disallowing repeats.")
+
+    num_tokens = len(token_list)
+    use_blank = blank != "none"
+    forced_blank = blank == "forced"
+    blank_node = num_tokens + 1
+    blank_label = num_tokens
+
+    graph = gtn.Graph(False)
+
+    # Nodes: start, then one per token. A token may be consumed by one or
+    # more consecutive emissions, e.g. [ab, ab, ab] transduces to [ab].
+    graph.add_node(True, True)
+    for _ in range(num_tokens):
+        graph.add_node(False, not forced_blank)
+
+    if use_blank:
+        graph.add_node()
+        graph.add_arc(0, blank_node, blank_label, gtn.epsilon)
+        graph.add_arc(blank_node, 0, gtn.epsilon)
+
+    # With a forced blank every token has to be entered from the blank node.
+    entry_node = blank_node if forced_blank else 0
+    for token in range(num_tokens):
+        token_node = token + 1
+        graph.add_arc(entry_node, token_node, token)
+        graph.add_arc(token_node, token_node, token, gtn.epsilon)
+
+        if not allow_repeats:
+            # Allow transitions to blank and to every token but this one.
+            graph.add_arc(token_node, blank_node, blank_label, gtn.epsilon)
+            for other in range(num_tokens):
+                if other != token:
+                    graph.add_arc(token_node, other + 1, other, other)
+        elif forced_blank:
+            # Allow transitions from token to blank only.
+            graph.add_arc(token_node, blank_node, blank_label, gtn.epsilon)
+        else:
+            # Allow transition from token to blank and all other tokens.
+            graph.add_arc(token_node, 0, gtn.epsilon)
+
+    return graph
+
+
+class TransducerLossFunction(torch.autograd.Function):
+    """Autograd function that evaluates the transducer loss with gtn graphs."""
+
+    @staticmethod
+    def forward(
+        ctx,
+        inputs,
+        targets,
+        tokens,
+        lexicon,
+        transition_params=None,
+        transitions=None,
+        reduction="none",
+    ) -> Tensor:
+        """Computes the batch-averaged negative alignment score.
+
+        Args:
+            ctx: Autograd context; the graphs, input shape and per-sample
+                scales are stored on it for `backward`.
+            inputs: Emission scores, unpacked as (B, T, C).
+            targets: One target sequence per batch element.
+            tokens: Token graph composed with the target decomposition.
+            lexicon: Lexicon graph composed with the target chain.
+            transition_params: Weights written into `transitions`.
+            transitions: Optional transition graph.
+            reduction: If "mean", each sample loss is divided by its target
+                length (empty targets keep a scale of 1.0).
+
+        Raises:
+            ValueError: If `transitions` is given without `transition_params`.
+
+        Returns:
+            Tensor: Mean of the scaled per-sample losses, on `inputs.device`.
+        """
+        B, T, C = inputs.shape
+
+        losses = [None] * B
+        emissions_graphs = [None] * B
+        use_transitions = transitions is not None
+
+        if use_transitions:
+            if transition_params is None:
+                raise ValueError("Specified transitions, but not transition params.")
+
+            transition_weights = transition_params.cpu().contiguous()
+            transitions.set_weights(transition_weights.data_ptr())
+            transitions.calc_grad = transition_params.requires_grad
+            transitions.zero_grad()
+
+        def process(b: int) -> None:
+            # Emission graph for sample b:
+            emissions = gtn.linear_graph(T, C, inputs.requires_grad)
+            emission_weights = inputs[b].cpu().contiguous()
+            emissions.set_weights(emission_weights.data_ptr())
+
+            target = make_chain_graph(targets[b])
+            target.arc_sort(True)
+
+            # Token to grapheme decomposition graph:
+            decomposition = gtn.compose(target, lexicon)
+            tokens_target = gtn.remove(gtn.project_output(decomposition))
+            tokens_target.arc_sort()
+
+            # Alignment graph:
+            composed = gtn.compose(tokens, tokens_target)
+            alignments = gtn.project_input(gtn.remove(composed))
+            alignments.arc_sort()
+
+            # Add transition scores:
+            if use_transitions:
+                alignments = gtn.intersect(transitions, alignments)
+                alignments.arc_sort()
+
+            score = gtn.forward_score(gtn.intersect(emissions, alignments))
+
+            # Normalize if needed:
+            if use_transitions:
+                norm = gtn.forward_score(gtn.intersect(emissions, transitions))
+                score = gtn.subtract(score, norm)
+
+            losses[b] = gtn.negate(score)
+
+            # Save for backward:
+            if emissions.calc_grad:
+                emissions_graphs[b] = emissions
+
+        gtn.parallel_for(process, range(B))
+
+        ctx.graphs = (losses, emissions_graphs, transitions)
+        ctx.input_shape = inputs.shape
+
+        # Optionally reduce by target length:
+        if reduction == "mean":
+            scales = [(1 / len(t) if len(t) > 0 else 1.0) for t in targets]
+        else:
+            scales = [1.0] * B
+        ctx.scales = scales
+
+        scaled = [graph.item() * scale for graph, scale in zip(losses, scales)]
+        return torch.mean(torch.tensor(scaled).to(inputs.device))
+
+    @staticmethod
+    def backward(ctx, grad_output) -> Tuple:
+        """Back-propagates through the stored loss graphs.
+
+        Returns:
+            Tuple: One entry per `forward` argument. Only the emissions and
+                the transition parameters can receive a gradient; every other
+                entry is None.
+        """
+        losses, emissions_graphs, transitions = ctx.graphs
+        scales = ctx.scales
+        B, T, C = ctx.input_shape
+
+        calc_emissions = ctx.needs_input_grad[0]
+        input_grad = torch.empty((B, T, C)) if calc_emissions else None
+
+        def process(b: int) -> None:
+            # Seed the backward pass with the per-sample scale.
+            seed = make_scalar_graph(scales[b])
+            gtn.backward(losses[b], seed)
+            emissions = emissions_graphs[b]
+            if calc_emissions:
+                weights = emissions.grad().weights_to_numpy()
+                input_grad[b] = torch.tensor(weights).view(1, T, C)
+
+        gtn.parallel_for(process, range(B))
+
+        if calc_emissions:
+            input_grad = input_grad.to(grad_output.device)
+            input_grad *= grad_output / B
+
+        transition_grad = None
+        if ctx.needs_input_grad[4]:
+            weights = transitions.grad().weights_to_numpy()
+            transition_grad = torch.tensor(weights).to(grad_output.device)
+            transition_grad *= grad_output / B
+
+        return (
+            input_grad,
+            None,  # target
+            None,  # tokens
+            None,  # lexicon
+            transition_grad,  # transition params
+            None,  # transitions graph
+            None,  # reduction
+        )
+
+
+TransducerLoss = TransducerLossFunction.apply
+
+
+class Transducer(nn.Module):
+    """Transducer criterion with Viterbi decoding, built on gtn graphs."""
+
+    def __init__(
+        self,
+        tokens: List,
+        graphemes_to_idx: Dict,
+        ngram: int = 0,
+        transitions: Optional[gtn.Graph] = None,
+        blank: str = "none",
+        allow_repeats: bool = True,
+        reduction: str = "none",
+    ) -> None:
+        """A generic transducer loss function.
+
+        Args:
+            tokens (List) : A list of iterable objects (e.g. strings, tuples, etc)
+                representing the output tokens of the model (e.g. letters,
+                word-pieces, words). For example ["a", "b", "ab", "ba", "aba"]
+                could be a list of sub-word tokens.
+            graphemes_to_idx (dict) : A dictionary mapping grapheme units (e.g.
+                "a", "b", ..) to their corresponding integer index.
+            ngram (int) : Order of the token-level transition model. If `ngram=0`
+                then no transition model is used.
+            transitions (gtn.Graph) : An already built transition graph. May
+                not be combined with `ngram > 0`.
+            blank (string) : Specifies the usage of blank token
+                'none' - do not use blank token
+                'optional' - allow an optional blank inbetween tokens
+                'forced' - force a blank inbetween tokens (also referred to as garbage token)
+            allow_repeats (boolean) : If false, then we don't allow paths with
+                consecutive tokens in the alignment graph. This keeps the graph
+                unambiguous in the sense that the same input cannot transduce to
+                different outputs.
+            reduction (string) : Passed to the loss function; "mean" divides
+                each sample loss by its target length.
+
+        Raises:
+            ValueError: If `blank` is not a valid option, or if both `ngram`
+                and `transitions` are specified.
+        """
+        super().__init__()
+        if blank not in ["optional", "forced", "none"]:
+            raise ValueError(
+                "Invalid value specified for blank. Must be in ['optional', 'forced', 'none']"
+            )
+        self.tokens = make_token_graph(tokens, blank=blank, allow_repeats=allow_repeats)
+        self.lexicon = make_lexicon_graph(tokens, graphemes_to_idx)
+        self.ngram = ngram
+        if ngram > 0 and transitions is not None:
+            raise ValueError("Only one of ngram and transitions may be specified")
+
+        if ngram > 0:
+            # The blank, when used, counts as one additional label.
+            transitions = make_transitions_graph(
+                ngram, len(tokens) + int(blank != "none"), True
+            )
+
+        if transitions is not None:
+            self.transitions = transitions
+            self.transitions.arc_sort()
+            # One learnable weight per arc of the transition graph.
+            self.transitions_params = nn.Parameter(
+                torch.zeros(self.transitions.num_arcs())
+            )
+        else:
+            self.transitions = None
+            self.transitions_params = None
+        self.reduction = reduction
+
+    def forward(self, inputs: Tensor, targets: Tensor) -> Tensor:
+        """Computes the transducer loss for a batch.
+
+        Args:
+            inputs (Tensor): Emissions, handed to `TransducerLoss`.
+            targets (Tensor): Target sequences, one per batch element.
+
+        Returns:
+            Tensor: The loss returned by `TransducerLoss`.
+        """
+        # The loss must be returned, otherwise callers always receive None.
+        return TransducerLoss(
+            inputs,
+            targets,
+            self.tokens,
+            self.lexicon,
+            self.transitions_params,
+            self.transitions,
+            self.reduction,
+        )
+
+    def viterbi(self, outputs: Tensor) -> List[Tensor]:
+        """Decodes the best token sequence for every batch element.
+
+        Args:
+            outputs (Tensor): Emissions, unpacked as (B, T, C).
+
+        Returns:
+            List[Tensor]: One IntTensor of predicted labels per batch element.
+        """
+        B, T, C = outputs.shape
+
+        if self.transitions is not None:
+            # The parameter is stored as `transitions_params` in `__init__`.
+            cpu_data = self.transitions_params.cpu().contiguous()
+            self.transitions.set_weights(cpu_data.data_ptr())
+            self.transitions.calc_grad = False
+
+        self.tokens.arc_sort()
+
+        paths = [None] * B
+
+        def process(b: int) -> None:
+            emissions = gtn.linear_graph(T, C, False)
+            cpu_data = outputs[b].cpu().contiguous()
+            emissions.set_weights(cpu_data.data_ptr())
+
+            if self.transitions is not None:
+                full_graph = gtn.intersect(emissions, self.transitions)
+            else:
+                full_graph = emissions
+
+            # Find the best path and remove back-off arcs:
+            path = gtn.remove(gtn.viterbi_path(full_graph))
+
+            # Left compose the viterbi path with the "alignment to token"
+            # transducer to get the outputs:
+            path = gtn.compose(path, self.tokens)
+
+            # When there are ambiguous paths (allow_repeats is true), we take
+            # the shortest:
+            path = gtn.viterbi_path(path)
+            path = gtn.remove(gtn.project_output(path))
+            paths[b] = path.labels_to_list()
+
+        gtn.parallel_for(process, range(B))
+        predictions = [torch.IntTensor(path) for path in paths]
+        return predictions
+
+
+def load_transducer_loss(
+    num_features: int,
+    ngram: int,
+    tokens: str,
+    lexicon: str,
+    transitions: str,
+    blank: str,
+    allow_repeats: bool,
+    prepend_wordsep: bool = False,
+    use_words: bool = False,
+    data_dir: Optional[Union[str, Path]] = None,
+    reduction: str = "mean",
+) -> Tuple[Transducer, int]:
+    """Builds a `Transducer` criterion from the processed IAM lines files.
+
+    Args:
+        num_features (int): Forwarded to `Preprocessor`.
+        ngram (int): Forwarded to `Transducer`.
+        tokens (str): Tokens file, relative to `data/processed/iam_lines`.
+        lexicon (str): Lexicon file, relative to `data/processed/iam_lines`.
+        transitions (str): Transition graph file, relative to
+            `data/processed/iam_lines`, loaded with `gtn.load`. Skipped if None.
+        blank (str): Forwarded to `Transducer`.
+        allow_repeats (bool): Forwarded to `Transducer`.
+        prepend_wordsep (bool): Forwarded to `Preprocessor`.
+        use_words (bool): Forwarded to `Preprocessor`.
+        data_dir (Optional[Union[str, Path]]): Directory of the raw iamdb data.
+            Defaults to `data/raw/iam/iamdb` under the directory four levels
+            above this file's directory.
+        reduction (str): Forwarded to `Transducer`.
+
+    Raises:
+        RuntimeError: If the default data directory does not exist.
+
+    Returns:
+        Tuple[Transducer, int]: The criterion and the number of output
+            classes, i.e. the preprocessor's token count plus one when a blank
+            is used.
+    """
+    project_root = Path(__file__).resolve().parents[4]
+
+    if data_dir is not None:
+        data_dir = Path(data_dir)
+    else:
+        data_dir = project_root / "data" / "raw" / "iam" / "iamdb"
+        logger.debug(f"Using data dir: {data_dir}")
+        if not data_dir.exists():
+            raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
+
+    processed_path = project_root / "data" / "processed" / "iam_lines"
+    tokens_path = processed_path / tokens
+    lexicon_path = processed_path / lexicon
+
+    if transitions is not None:
+        transitions_file = processed_path / transitions
+        transitions = gtn.load(str(transitions_file))
+
+    preprocessor = Preprocessor(
+        data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep,
+    )
+    num_tokens = preprocessor.num_tokens
+
+    criterion = Transducer(
+        preprocessor.tokens,
+        preprocessor.graphemes_to_index,
+        ngram=ngram,
+        transitions=transitions,
+        blank=blank,
+        allow_repeats=allow_repeats,
+        reduction=reduction,
+    )
+
+    # The blank, when used, adds one extra output class.
+    num_classes = num_tokens + int(blank != "none")
+    return criterion, num_classes