diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-01 23:10:12 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-01 23:10:12 +0200 |
commit | db86cef2d308f58325278061c6aa177a535e7e03 (patch) | |
tree | a013fa85816337269f9cdc5a8992813fa62d299d | |
parent | b980a281712a5b1ee7ee5bd8f5d4762cd91a070b (diff) |
Replace attr with attrs
-rw-r--r-- | poetry.lock | 235 | ||||
-rw-r--r-- | pyproject.toml | 2 | ||||
-rw-r--r-- | text_recognizer/data/base_data_module.py | 32 | ||||
-rw-r--r-- | text_recognizer/data/base_dataset.py | 12 | ||||
-rw-r--r-- | text_recognizer/data/emnist.py | 4 | ||||
-rw-r--r-- | text_recognizer/data/emnist_lines.py | 18 | ||||
-rw-r--r-- | text_recognizer/data/iam.py | 6 | ||||
-rw-r--r-- | text_recognizer/data/iam_extended_paragraphs.py | 4 | ||||
-rw-r--r-- | text_recognizer/data/iam_lines.py | 8 | ||||
-rw-r--r-- | text_recognizer/data/iam_paragraphs.py | 13 | ||||
-rw-r--r-- | text_recognizer/data/iam_synthetic_paragraphs.py | 4 | ||||
-rw-r--r-- | text_recognizer/models/base.py | 24 | ||||
-rw-r--r-- | text_recognizer/models/metrics.py | 10 | ||||
-rw-r--r-- | text_recognizer/models/transformer.py | 24 | ||||
-rw-r--r-- | text_recognizer/networks/efficientnet/efficientnet.py | 32 | ||||
-rw-r--r-- | text_recognizer/networks/efficientnet/mbconv.py | 71 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/attention.py | 24 |
17 files changed, 277 insertions, 246 deletions
diff --git a/poetry.lock b/poetry.lock index 7bf5385..3b1e6a1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2,7 +2,7 @@ name = "absl-py" version = "1.0.0" description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py." -category = "main" +category = "dev" optional = false python-versions = ">=3.6" @@ -13,7 +13,7 @@ six = "*" name = "aiohttp" version = "3.8.1" description = "Async http client/server framework (asyncio)" -category = "main" +category = "dev" optional = false python-versions = ">=3.6" @@ -33,7 +33,7 @@ speedups = ["aiodns", "brotli", "cchardet"] name = "aiosignal" version = "1.2.0" description = "aiosignal: a list of registered asynchronous callbacks" -category = "main" +category = "dev" optional = false python-versions = ">=3.6" @@ -44,7 +44,7 @@ frozenlist = ">=1.1.0" name = "antlr4-python3-runtime" version = "4.9.3" description = "ANTLR 4.9.3 runtime for Python 3.7" -category = "main" +category = "dev" optional = false python-versions = "*" @@ -130,7 +130,7 @@ test = ["astroid", "pytest"] name = "async-timeout" version = "4.0.2" description = "Timeout context manager for asyncio programs" -category = "main" +category = "dev" optional = false python-versions = ">=3.6" @@ -143,14 +143,6 @@ optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" [[package]] -name = "attr" -version = "0.3.1" -description = "Simple decorator to set attributes of target function or class in a DRY way." -category = "main" -optional = false -python-versions = "*" - -[[package]] name = "attrs" version = "21.4.0" description = "Classes Without Boilerplate" @@ -159,9 +151,9 @@ optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" [package.extras] -dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "furo", "sphinx", "sphinx-notfound-page", "pre-commit", "cloudpickle"] -docs = ["furo", "sphinx", "zope.interface", "sphinx-notfound-page"] -tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "cloudpickle"] +dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope-interface", "furo", "sphinx", "sphinx-notfound-page", "pre-commit", "cloudpickle"] +docs = ["furo", "sphinx", "zope-interface", "sphinx-notfound-page"] +tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope-interface", "cloudpickle"] tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "cloudpickle"] [[package]] @@ -250,7 +242,6 @@ six = ">=1.9.0" webencodings = "*" [package.extras] -css = ["tinycss2 (>=1.1.0)"] dev = ["pip-tools (==6.5.1)", "pytest (==7.1.1)", "flake8 (==4.0.1)", "tox (==3.24.5)", "sphinx (==4.3.2)", "twine (==4.0.0)", "wheel (==0.37.1)", "hashin (==0.17.0)", "black (==22.3.0)", "mypy (==0.942)"] [[package]] @@ -263,9 +254,9 @@ python-versions = "*" [[package]] name = "cachetools" -version = "5.1.0" +version = "5.2.0" description = "Extensible memoizing collections and decorators" -category = "main" +category = "dev" optional = false python-versions = "~=3.7" @@ -273,7 +264,7 @@ python-versions = "~=3.7" name = "certifi" version = "2022.5.18.1" description = "Python package for providing Mozilla's CA Bundle." -category = "main" +category = "dev" optional = false python-versions = ">=3.6" @@ -292,7 +283,7 @@ pycparser = "*" name = "charset-normalizer" version = "2.0.12" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." -category = "main" +category = "dev" optional = false python-versions = ">=3.5.0" @@ -311,7 +302,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" name = "colorama" version = "0.4.4" description = "Cross-platform colored terminal text." -category = "main" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" @@ -408,7 +399,7 @@ python-versions = "*" name = "einops" version = "0.3.2" description = "A new flavour of deep learning operations" -category = "main" +category = "dev" optional = false python-versions = "*" @@ -528,6 +519,7 @@ python-versions = "*" [package.dependencies] pycodestyle = "*" +setuptools = "*" [[package]] name = "flake8-polyfill" @@ -566,7 +558,7 @@ woff = ["zopfli (>=0.1.4)", "brotlicffi (>=0.8.0)", "brotli (>=1.0.1)"] name = "frozenlist" version = "1.3.0" description = "A list-like structure which implements collections.abc.MutableSequence" -category = "main" +category = "dev" optional = false python-versions = ">=3.7" @@ -574,7 +566,7 @@ python-versions = ">=3.7" name = "fsspec" version = "2022.5.0" description = "File-system specification" -category = "main" +category = "dev" optional = false python-versions = ">=3.7" @@ -631,7 +623,7 @@ gitdb = ">=4.0.1,<5" name = "google-auth" version = "2.6.6" description = "Google Authentication Library" -category = "main" +category = "dev" optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*" @@ -650,7 +642,7 @@ reauth = ["pyu2f (>=0.1.5)"] name = "google-auth-oauthlib" version = "0.4.6" description = "Google Authentication Library" -category = "main" +category = "dev" optional = false python-versions = ">=3.6" @@ -665,7 +657,7 @@ tool = ["click (>=6.0.0)"] name = "grpcio" version = "1.46.3" description = "HTTP/2-based RPC framework" -category = "main" +category = "dev" optional = false python-versions = ">=3.6" @@ -690,7 +682,7 @@ numpy = ">=1.14.5" name = "hydra-core" version = "1.2.0" description = "A framework for elegantly configuring complex applications" -category = "main" +category = "dev" optional = false python-versions = "*" @@ -703,7 +695,7 @@ packaging = "*" name = "idna" version = "3.3" description = "Internationalized Domain Names in Applications (IDNA)" -category = "main" +category = "dev" optional = false python-versions = ">=3.5" @@ -711,7 +703,7 @@ python-versions = ">=3.5" name = "importlib-metadata" version = "4.11.4" description = "Read metadata from Python packages" -category = "main" +category = "dev" optional = false python-versions = ">=3.7" @@ -721,7 +713,7 @@ zipp = ">=0.5" [package.extras] docs = ["sphinx", "jaraco.packaging (>=9)", "rst.linker (>=1.9)"] perf = ["ipython"] -testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "packaging", "pyfakefs", "flufl.flake8", "pytest-perf (>=0.9.2)", "pytest-black (>=0.3.7)", "pytest-mypy (>=0.9.1)", "importlib-resources (>=1.3)"] +testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "packaging", "pyfakefs", "flufl-flake8", "pytest-perf (>=0.9.2)", "pytest-black (>=0.3.7)", "pytest-mypy (>=0.9.1)", "importlib-resources (>=1.3)"] [[package]] name = "ipykernel" @@ -748,7 +740,7 @@ test = ["pytest (>=6.0)", "pytest-cov", "flaky", "ipyparallel", "pre-commit", "p [[package]] name = "ipython" -version = "8.3.0" +version = "8.4.0" description = "IPython: Productive Interactive Computing" category = "dev" optional = false @@ -765,6 +757,7 @@ pexpect = {version = ">4.3", markers = "sys_platform != \"win32\""} pickleshare = "*" prompt-toolkit = ">=2.0.0,<3.0.0 || >3.0.0,<3.0.1 || >3.0.1,<3.1.0" pygments = ">=2.4.0" +setuptools = ">=18.5" stack-data = "*" traitlets = ">=5" @@ -1044,7 +1037,7 @@ python-versions = ">=3.7" name = "loguru" version = "0.6.0" description = "Python logging made (stupidly) simple" -category = "main" +category = "dev" optional = false python-versions = ">=3.5" @@ -1059,7 +1052,7 @@ dev = ["colorama (>=0.3.4)", "docutils (==0.16)", "flake8 (>=3.7.7)", "tox (>=3. name = "markdown" version = "3.3.7" description = "Python implementation of Markdown." -category = "main" +category = "dev" optional = false python-versions = ">=3.6" @@ -1135,7 +1128,7 @@ python-versions = ">=3.5" name = "multidict" version = "6.0.2" description = "multidict implementation" -category = "main" +category = "dev" optional = false python-versions = ">=3.7" @@ -1181,7 +1174,7 @@ test = ["pytest", "pytest-tornasync", "pytest-console-scripts"] [[package]] name = "nbclient" -version = "0.6.3" +version = "0.6.4" description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor." category = "dev" optional = false @@ -1191,7 +1184,7 @@ python-versions = ">=3.7.0" jupyter-client = ">=6.1.5" nbformat = ">=5.0" nest-asyncio = "*" -traitlets = ">=5.0.0" +traitlets = ">=5.2.2" [package.extras] sphinx = ["autodoc-traits", "mock", "moto", "myst-parser", "Sphinx (>=1.7)", "sphinx-book-theme"] @@ -1325,7 +1318,7 @@ test = ["pytest", "pytest-tornasync", "pytest-console-scripts"] name = "numpy" version = "1.22.4" description = "NumPy is the fundamental package for array computing with Python." -category = "main" +category = "dev" optional = false python-versions = ">=3.8" @@ -1333,7 +1326,7 @@ python-versions = ">=3.8" name = "oauthlib" version = "3.2.0" description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic" -category = "main" +category = "dev" optional = false python-versions = ">=3.6" @@ -1344,9 +1337,9 @@ signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"] [[package]] name = "omegaconf" -version = "2.2.1" +version = "2.2.2" description = "A flexible configuration library" -category = "main" +category = "dev" optional = false python-versions = ">=3.6" @@ -1374,7 +1367,7 @@ numpy = [ name = "packaging" version = "21.3" description = "Core utilities for Python packages" -category = "main" +category = "dev" optional = false python-versions = ">=3.6" @@ -1448,7 +1441,7 @@ python-versions = "*" name = "pillow" version = "9.1.1" description = "Python Imaging Library (Fork)" -category = "main" +category = "dev" optional = false python-versions = ">=3.7" @@ -1507,7 +1500,7 @@ wcwidth = "*" name = "protobuf" version = "3.20.1" description = "Protocol Buffers" -category = "main" +category = "dev" optional = false python-versions = ">=3.7" @@ -1553,7 +1546,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" name = "pyasn1" version = "0.4.8" description = "ASN.1 types and codecs" -category = "main" +category = "dev" optional = false python-versions = "*" @@ -1561,7 +1554,7 @@ python-versions = "*" name = "pyasn1-modules" version = "0.2.8" description = "A collection of ASN.1-based protocols modules." -category = "main" +category = "dev" optional = false python-versions = "*" @@ -1588,7 +1581,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" name = "pydeprecate" version = "0.3.2" description = "Deprecation tooling" -category = "main" +category = "dev" optional = false python-versions = ">=3.6" @@ -1626,7 +1619,7 @@ python-versions = ">=3.6" name = "pyparsing" version = "3.0.9" description = "pyparsing module - Classes and methods to define and execute parsing grammars" -category = "main" +category = "dev" optional = false python-versions = ">=3.6.8" @@ -1706,9 +1699,9 @@ six = ">=1.5" [[package]] name = "pytorch-lightning" -version = "1.6.3" +version = "1.6.4" description = "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate." -category = "main" +category = "dev" optional = false python-versions = ">=3.7" @@ -1716,7 +1709,8 @@ python-versions = ">=3.7" fsspec = {version = ">=2021.05.0,<2021.06.0 || >2021.06.0", extras = ["http"]} numpy = ">=1.17.2" packaging = ">=17.0" -pyDeprecate = ">=0.3.1,<0.4.0" +protobuf = "<=3.20.1" +pyDeprecate = ">=0.3.1" PyYAML = ">=5.4" tensorboard = ">=2.2.0" torch = ">=1.8" @@ -1725,14 +1719,17 @@ tqdm = ">=4.57.0" typing-extensions = ">=4.0.0" [package.extras] -all = ["matplotlib (>3.1)", "horovod (>=0.21.2,!=0.24.0)", "torchtext (>=0.9)", "omegaconf (>=2.0.5)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.7.1)", "gcsfs (>=2021.5.0)", "rich (>=10.2.2,<10.15.0 || >=10.16.0)", "neptune-client (>=0.10.0)", "comet-ml (>=3.1.12)", "mlflow (>=1.0.0)", "test-tube (>=0.7.5)", "wandb (>=0.8.21)", "coverage (>5.2.0,<6.3)", "codecov (>=2.1)", "pytest (>=6.0)", "pytest-rerunfailures (>=10.2)", "twine (==3.2)", "mypy (>=0.920)", "flake8 (>=3.9.2)", "pre-commit (>=1.0)", "pytest-forked", "sklearn", "jsonargparse", "cloudpickle (>=1.3)", "scikit-learn (>0.22.1)", "onnxruntime", "pandas", "torchvision (>=0.9)", "gym[classic_control] (>=0.17.0)", "ipython"] -cpu = ["matplotlib (>3.1)", "torchtext (>=0.9)", "omegaconf (>=2.0.5)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.7.1)", "gcsfs (>=2021.5.0)", "rich (>=10.2.2,<10.15.0 || >=10.16.0)", "neptune-client (>=0.10.0)", "comet-ml (>=3.1.12)", "mlflow (>=1.0.0)", "test-tube (>=0.7.5)", "wandb (>=0.8.21)", "coverage (>5.2.0,<6.3)", "codecov (>=2.1)", "pytest (>=6.0)", "pytest-rerunfailures (>=10.2)", "twine (==3.2)", "mypy (>=0.920)", "flake8 (>=3.9.2)", "pre-commit (>=1.0)", "pytest-forked", "sklearn", "jsonargparse", "cloudpickle (>=1.3)", "scikit-learn (>0.22.1)", "onnxruntime", "pandas", "torchvision (>=0.9)", "gym[classic_control] (>=0.17.0)", "ipython"] -cpu-extra = ["matplotlib (>3.1)", "torchtext (>=0.9)", "omegaconf (>=2.0.5)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.7.1)", "gcsfs (>=2021.5.0)", "rich (>=10.2.2,<10.15.0 || >=10.16.0)"] -dev = ["matplotlib (>3.1)", "horovod (>=0.21.2,!=0.24.0)", "torchtext (>=0.9)", "omegaconf (>=2.0.5)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.7.1)", "gcsfs (>=2021.5.0)", "rich (>=10.2.2,<10.15.0 || >=10.16.0)", "neptune-client (>=0.10.0)", "comet-ml (>=3.1.12)", "mlflow (>=1.0.0)", "test-tube (>=0.7.5)", "wandb (>=0.8.21)", "coverage (>5.2.0,<6.3)", "codecov (>=2.1)", "pytest (>=6.0)", "pytest-rerunfailures (>=10.2)", "twine (==3.2)", "mypy (>=0.920)", "flake8 (>=3.9.2)", "pre-commit (>=1.0)", "pytest-forked", "sklearn", "jsonargparse", "cloudpickle (>=1.3)", "scikit-learn (>0.22.1)", "onnxruntime", "pandas"] +all = ["matplotlib (>3.1)", "torchtext (>=0.9)", "omegaconf (>=2.0.5)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.7.1)", "gcsfs (>=2021.5.0)", "rich (>=10.2.2,!=10.15.*)", "neptune-client (>=0.10.0)", "comet-ml (>=3.1.12)", "mlflow (>=1.0.0)", "test-tube (>=0.7.5)", "wandb (>=0.8.21)", "coverage (>=6.4)", "codecov (>=2.1)", "pytest (>=6.0)", "pytest-rerunfailures (>=10.2)", "mypy (>=0.920)", "flake8 (>=3.9.2)", "pre-commit (>=1.0)", "pytest-forked", "cloudpickle (>=1.3)", "scikit-learn (>0.22.1)", "onnxruntime", "pandas", "torchvision (>=0.9)", "gym[classic_control] (>=0.17.0)", "ipython", "fairscale (>=0.4.5)", "deepspeed", "horovod (>=0.21.2,!=0.24.0)", "hivemind (>=1.0.1)"] +deepspeed = ["deepspeed"] +dev = ["matplotlib (>3.1)", "torchtext (>=0.9)", "omegaconf (>=2.0.5)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.7.1)", "gcsfs (>=2021.5.0)", "rich (>=10.2.2,!=10.15.*)", "neptune-client (>=0.10.0)", "comet-ml (>=3.1.12)", "mlflow (>=1.0.0)", "test-tube (>=0.7.5)", "wandb (>=0.8.21)", "coverage (>=6.4)", "codecov (>=2.1)", "pytest (>=6.0)", "pytest-rerunfailures (>=10.2)", "mypy (>=0.920)", "flake8 (>=3.9.2)", "pre-commit (>=1.0)", "pytest-forked", "cloudpickle (>=1.3)", "scikit-learn (>0.22.1)", "onnxruntime", "pandas"] examples = ["torchvision (>=0.9)", "gym[classic_control] (>=0.17.0)", "ipython"] -extra = ["matplotlib (>3.1)", "horovod (>=0.21.2,!=0.24.0)", "torchtext (>=0.9)", "omegaconf (>=2.0.5)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.7.1)", "gcsfs (>=2021.5.0)", "rich (>=10.2.2,<10.15.0 || >=10.16.0)"] +extra = ["matplotlib (>3.1)", "torchtext (>=0.9)", "omegaconf (>=2.0.5)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.7.1)", "gcsfs (>=2021.5.0)", "rich (>=10.2.2,!=10.15.*)"] +fairscale = ["fairscale (>=0.4.5)"] +hivemind = ["hivemind (>=1.0.1)"] +horovod = ["horovod (>=0.21.2,!=0.24.0)"] loggers = ["neptune-client (>=0.10.0)", "comet-ml (>=3.1.12)", "mlflow (>=1.0.0)", "test-tube (>=0.7.5)", "wandb (>=0.8.21)"] -test = ["coverage (>5.2.0,<6.3)", "codecov (>=2.1)", "pytest (>=6.0)", "pytest-rerunfailures (>=10.2)", "twine (==3.2)", "mypy (>=0.920)", "flake8 (>=3.9.2)", "pre-commit (>=1.0)", "pytest-forked", "sklearn", "jsonargparse", "cloudpickle (>=1.3)", "scikit-learn (>0.22.1)", "onnxruntime", "pandas"] +strategies = ["fairscale (>=0.4.5)", "deepspeed", "horovod (>=0.21.2,!=0.24.0)", "hivemind (>=1.0.1)"] +test = ["coverage (>=6.4)", "codecov (>=2.1)", "pytest (>=6.0)", "pytest-rerunfailures (>=10.2)", "mypy (>=0.920)", "flake8 (>=3.9.2)", "pre-commit (>=1.0)", "pytest-forked", "cloudpickle (>=1.3)", "scikit-learn (>0.22.1)", "onnxruntime", "pandas"] [[package]] name = "pytz" @@ -1762,7 +1759,7 @@ python-versions = ">=3.7" name = "pyyaml" version = "6.0" description = "YAML parser and emitter for Python" -category = "main" +category = "dev" optional = false python-versions = ">=3.6" @@ -1826,7 +1823,7 @@ python-versions = ">=3.6" name = "requests" version = "2.27.1" description = "Python HTTP for Humans." -category = "main" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" @@ -1844,7 +1841,7 @@ use_chardet_on_py3 = ["chardet (>=3.0.2,<5)"] name = "requests-oauthlib" version = "1.3.1" description = "OAuthlib authentication support for Requests." -category = "main" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" @@ -1859,7 +1856,7 @@ rsa = ["oauthlib[signedtoken] (>=3.0.0)"] name = "rsa" version = "4.8" description = "Pure-Python RSA implementation" -category = "main" +category = "dev" optional = false python-versions = ">=3.6,<4" @@ -1879,6 +1876,7 @@ Click = ">=6.0" dparse = ">=0.5.1" packaging = "*" requests = "*" +setuptools = "*" [[package]] name = "scipy" @@ -1946,6 +1944,18 @@ python-versions = ">=3.6" test = ["pytest"] [[package]] +name = "setuptools" +version = "59.5.0" +description = "Easily download, build, install, upgrade, and uninstall Python packages" +category = "dev" +optional = false +python-versions = ">=3.6" + +[package.extras] +docs = ["sphinx", "jaraco.packaging (>=8.2)", "rst.linker (>=1.9)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "sphinx-inline-tabs", "sphinxcontrib-towncrier", "furo"] +testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "mock", "flake8-2020", "virtualenv (>=13.0.0)", "pytest-virtualenv (>=1.2.7)", "wheel", "paver", "pip (>=19.1)", "jaraco.envs (>=2.2)", "pytest-xdist", "sphinx", "jaraco.path (>=3.2.0)", "pytest-black (>=0.3.7)", "pytest-mypy"] + +[[package]] name = "setuptools-scm" version = "6.4.2" description = "the blessed package to manage your versions by scm tags" @@ -1955,6 +1965,7 @@ python-versions = ">=3.6" [package.dependencies] packaging = ">=20.0" +setuptools = "*" tomli = ">=1.0.0" [package.extras] @@ -1973,7 +1984,7 @@ python-versions = ">=3.5" name = "six" version = "1.16.0" description = "Python 2 and 3 compatibility utilities" -category = "main" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" @@ -1981,7 +1992,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" name = "smart-open" version = "5.2.1" description = "Utils for streaming large files (S3, HDFS, GCS, Azure Blob Storage, gzip, bz2...)" -category = "main" +category = "dev" optional = false python-versions = ">=3.6,<4.0" @@ -2057,7 +2068,7 @@ pbr = ">=2.0.0,<2.1.0 || >2.1.0" name = "tensorboard" version = "2.9.0" description = "TensorBoard lets you watch Tensors Flow" -category = "main" +category = "dev" optional = false python-versions = ">=3.6" @@ -2070,15 +2081,17 @@ markdown = ">=2.6.8" numpy = ">=1.12.0" protobuf = ">=3.9.2" requests = ">=2.21.0,<3" +setuptools = ">=41.0.0" tensorboard-data-server = ">=0.6.0,<0.7.0" tensorboard-plugin-wit = ">=1.6.0" werkzeug = ">=1.0.1" +wheel = ">=0.26" [[package]] name = "tensorboard-data-server" version = "0.6.1" description = "Fast data loading for TensorBoard" -category = "main" +category = "dev" optional = false python-versions = ">=3.6" @@ -2086,7 +2099,7 @@ python-versions = ">=3.6" name = "tensorboard-plugin-wit" version = "1.8.1" description = "What-If Tool TensorBoard plugin." -category = "main" +category = "dev" optional = false python-versions = "*" @@ -2141,7 +2154,7 @@ python-versions = ">=3.7" name = "torch" version = "1.11.0" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" -category = "main" +category = "dev" optional = false python-versions = ">=3.7.0" @@ -2150,7 +2163,7 @@ typing-extensions = "*" [[package]] name = "torchinfo" -version = "1.6.6" +version = "1.7.0" description = "Model summary in PyTorch, based off of the original torchsummary." category = "dev" optional = false @@ -2160,7 +2173,7 @@ python-versions = ">=3.7" name = "torchmetrics" version = "0.4.1" description = "PyTorch native Metrics" -category = "main" +category = "dev" optional = false python-versions = ">=3.6" @@ -2176,7 +2189,7 @@ image = ["scipy", "torchvision", "torch-fidelity"] name = "torchvision" version = "0.12.0" description = "image and video datasets and models for torch deep learning" -category = "main" +category = "dev" optional = false python-versions = ">=3.7" @@ -2202,7 +2215,7 @@ python-versions = ">= 3.5" name = "tqdm" version = "4.64.0" description = "Fast, Extensible Progress Meter" -category = "main" +category = "dev" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" @@ -2217,7 +2230,7 @@ telegram = ["requests"] [[package]] name = "traitlets" -version = "5.2.1.post0" +version = "5.2.2.post1" description = "" category = "dev" optional = false @@ -2250,7 +2263,7 @@ test = ["pytest", "typing-extensions", "mypy"] name = "typing-extensions" version = "4.2.0" description = "Backported and Experimental Type Hints for Python 3.7+" -category = "main" +category = "dev" optional = false python-versions = ">=3.7" @@ -2258,7 +2271,7 @@ python-versions = ">=3.7" name = "urllib3" version = "1.26.9" description = "HTTP library with thread-safe connection pooling, file post, and more." -category = "main" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4" @@ -2269,7 +2282,7 @@ socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] [[package]] name = "wandb" -version = "0.12.16" +version = "0.12.17" description = "A CLI and library for interacting with the Weights and Biases API." category = "dev" optional = false @@ -2281,13 +2294,14 @@ docker-pycreds = ">=0.4.0" GitPython = ">=1.0.0" pathtools = "*" promise = ">=2.0,<3" -protobuf = ">=3.12.0" +protobuf = ">=3.12.0,<4.0dev" psutil = ">=5.0.0" python-dateutil = ">=2.6.1" PyYAML = "*" requests = ">=2.0.0,<3" sentry-sdk = ">=1.0.0" setproctitle = "*" +setuptools = "*" shortuuid = ">=0.5.0" six = ">=1.13.0" @@ -2334,7 +2348,7 @@ test = ["websockets"] name = "werkzeug" version = "2.1.2" description = "The comprehensive WSGI web application library." -category = "main" +category = "dev" optional = false python-versions = ">=3.7" @@ -2342,6 +2356,17 @@ python-versions = ">=3.7" watchdog = ["watchdog"] [[package]] +name = "wheel" +version = "0.37.1" +description = "A built-package format for Python" +category = "dev" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" + +[package.extras] +test = ["pytest (>=3.0.0)", "pytest-cov"] + +[[package]] name = "widgetsnbextension" version = "3.6.0" description = "IPython HTML widgets for Jupyter" @@ -2356,7 +2381,7 @@ notebook = ">=4.4.1" name = "win32-setctime" version = "1.1.0" description = "A small Python utility to set file creation time on Windows" -category = "main" +category = "dev" optional = false python-versions = ">=3.5" @@ -2367,7 +2392,7 @@ dev = ["pytest (>=4.6.2)", "black (>=19.3b0)"] name = "yarl" version = "1.7.2" description = "Yet another URL library" -category = "main" +category = "dev" optional = false python-versions = ">=3.6" @@ -2379,18 +2404,18 @@ multidict = ">=4.0" name = "zipp" version = "3.8.0" description = "Backport of pathlib-compatible object wrapper for zip files" -category = "main" +category = "dev" optional = false python-versions = ">=3.7" [package.extras] docs = ["sphinx", "jaraco.packaging (>=9)", "rst.linker (>=1.9)"] -testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "jaraco.itertools", "func-timeout", "pytest-black (>=0.3.7)", "pytest-mypy (>=0.9.1)"] +testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "jaraco-itertools", "func-timeout", "pytest-black (>=0.3.7)", "pytest-mypy (>=0.9.1)"] [metadata] lock-version = "1.1" python-versions = "^3.9" -content-hash = "608eb05815709f89d2b622af655b68624f43064413bb4cd6e5d681029bbb76c7" +content-hash = "b0c8b83118a86644eb9ceb5e33f9944d13016592a849111163a853dca3adb2aa" [metadata.files] absl-py = [ @@ -2529,10 +2554,6 @@ atomicwrites = [ {file = "atomicwrites-1.4.0-py2.py3-none-any.whl", hash = "sha256:6d1784dea7c0c8d4a5172b6c620f40b6e4cbfdf96d783691f2e1302a7b88e197"}, {file = "atomicwrites-1.4.0.tar.gz", hash = "sha256:ae70396ad1a434f9c7046fd2dd196fc04b12f9e91ffb859164193be8b6168a7a"}, ] -attr = [ - {file = "attr-0.3.1-py2-none-any.whl", hash = "sha256:0b1aaddb85bd9e9c4bd75092f4440d6616ff40b0df0437f00771871670f7c9fd"}, - {file = "attr-0.3.1.tar.gz", hash = "sha256:9091548058d17f132596e61fa7518e504f76b9a4c61ca7d86e1f96dbf7d4775d"}, -] attrs = [ {file = "attrs-21.4.0-py2.py3-none-any.whl", hash = "sha256:2d27e3784d7a565d36ab851fe94887c5eccd6a463168875832a1be79c82828b4"}, {file = "attrs-21.4.0.tar.gz", hash = "sha256:626ba8234211db98e869df76230a137c4c40a12d72445c45d5f5b716f076e2fd"}, @@ -2566,8 +2587,8 @@ boltons = [ {file = "boltons-20.2.1.tar.gz", hash = "sha256:dd362291a460cc1e0c2e91cc6a60da3036ced77099b623112e8f833e6734bdc5"}, ] cachetools = [ - {file = "cachetools-5.1.0-py3-none-any.whl", hash = "sha256:4ebbd38701cdfd3603d1f751d851ed248ab4570929f2d8a7ce69e30c420b141c"}, - {file = "cachetools-5.1.0.tar.gz", hash = "sha256:8b3b8fa53f564762e5b221e9896798951e7f915513abf2ba072ce0f07f3f5a98"}, + {file = "cachetools-5.2.0-py3-none-any.whl", hash = "sha256:f9f17d2aec496a9aa6b76f53e3b614c965223c061982d434d160f930c698a9db"}, + {file = "cachetools-5.2.0.tar.gz", hash = "sha256:6a94c6402995a99c3970cc7e4884bb60b4a8639938157eeed436098bf9831757"}, ] certifi = [ {file = "certifi-2022.5.18.1-py3-none-any.whl", hash = "sha256:f1d53542ee8cbedbe2118b5686372fb33c297fcd6379b050cca0ef13a597382a"}, @@ -2994,8 +3015,8 @@ ipykernel = [ {file = "ipykernel-6.13.0.tar.gz", hash = "sha256:0e28273e290858393e86e152b104e5506a79c13d25b951ac6eca220051b4be60"}, ] ipython = [ - {file = "ipython-8.3.0-py3-none-any.whl", hash = "sha256:341456643a764c28f670409bbd5d2518f9b82c013441084ff2c2fc999698f83b"}, - {file = "ipython-8.3.0.tar.gz", hash = "sha256:807ae3cf43b84693c9272f70368440a9a7eaa2e7e6882dad943c32fbf7e51402"}, + {file = "ipython-8.4.0-py3-none-any.whl", hash = "sha256:7ca74052a38fa25fe9bedf52da0be7d3fdd2fb027c3b778ea78dfe8c212937d1"}, + {file = "ipython-8.4.0.tar.gz", hash = "sha256:f2db3a10254241d9b447232cec8b424847f338d9d36f9a577a6192c332a46abd"}, ] ipython-genutils = [ {file = "ipython_genutils-0.2.0-py2.py3-none-any.whl", hash = "sha256:72dd37233799e619666c9f639a9da83c34013a73e8bbc79a7a6348d93c61fab8"}, @@ -3295,8 +3316,8 @@ nbclassic = [ {file = "nbclassic-0.3.7.tar.gz", hash = "sha256:36dbaa88ffaf5dc05d149deb97504b86ba648f4a80a60b8a58ac94acab2daeb5"}, ] nbclient = [ - {file = "nbclient-0.6.3-py3-none-any.whl", hash = "sha256:2747ac9b385720d8a6c34f2f71e72cbe64aec6cadaadcc064a4df0b0e99c5874"}, - {file = "nbclient-0.6.3.tar.gz", hash = "sha256:b80726fc1fb89a0e8f8be1e77e28d0026b1e8ed90bc143c8a0c7622e4f8cdd9e"}, + {file = "nbclient-0.6.4-py3-none-any.whl", hash = "sha256:f251bba200a2b401a061dfd700a7a70b5772f664fb49d4a2d3e5536ec0e98c76"}, + {file = "nbclient-0.6.4.tar.gz", hash = "sha256:cdef7757cead1735d2c70cc66095b072dced8a1e6d1c7639ef90cd3e04a11f2e"}, ] nbconvert = [ {file = "nbconvert-6.5.0-py3-none-any.whl", hash = "sha256:c56dd0b8978a1811a5654f74c727ff16ca87dd5a43abd435a1c49b840fcd8360"}, @@ -3351,8 +3372,8 @@ oauthlib = [ {file = "oauthlib-3.2.0.tar.gz", hash = "sha256:23a8208d75b902797ea29fd31fa80a15ed9dc2c6c16fe73f5d346f83f6fa27a2"}, ] omegaconf = [ - {file = "omegaconf-2.2.1-py3-none-any.whl", hash = "sha256:5ce512b0a8996b5acddc7b30f6bffa337a4b0ada1d96b4270588365e2a69e6d5"}, - {file = "omegaconf-2.2.1.tar.gz", hash = "sha256:6796a5b51a0112410d9940c3a5d2938d5e644377a5517c02db1aef99e03b8af2"}, + {file = "omegaconf-2.2.2-py3-none-any.whl", hash = "sha256:556917181487fb66fe832d3c7b324f51b2f4c8adc373dd5091be921501b7d420"}, + {file = "omegaconf-2.2.2.tar.gz", hash = "sha256:65c85b2a84669a570c70f2df00de3cebcd9b47a8587d3c53b1aa5766bb096f77"}, ] opencv-python = [ {file = "opencv-python-4.5.5.64.tar.gz", hash = "sha256:f65de0446a330c3b773cd04ba10345d8ce1b15dcac3f49770204e37602d0b3f7"}, @@ -3619,8 +3640,8 @@ python-dateutil = [ {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, ] pytorch-lightning = [ - {file = "pytorch-lightning-1.6.3.tar.gz", hash = "sha256:beb1f36a6dae91f5fef0959a04af1092dff4f3f4d99c20f0e033f84e615903e3"}, - {file = "pytorch_lightning-1.6.3-py3-none-any.whl", hash = "sha256:5419adaee5bb8057b1dad69d2cbb79f823f54a94bb67cb47fe75cdf8c1bc5616"}, + {file = "pytorch-lightning-1.6.4.tar.gz", hash = "sha256:5459f2c3e67676ec59e94576d1499e9559d214e7df41eadd135db64b4ccf54b9"}, + {file = "pytorch_lightning-1.6.4-py3-none-any.whl", hash = "sha256:0f42f93116a3fcb6fd8c9ea45cf7c918e4aa3f848ae21d0e9ac2bf39f2865dd7"}, ] pytz = [ {file = "pytz-2022.1-py2.py3-none-any.whl", hash = "sha256:e68985985296d9a66a881eb3193b0906246245294a881e7c8afe623866ac6a5c"}, @@ -3945,6 +3966,10 @@ setproctitle = [ {file = "setproctitle-1.2.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:97accd117392b1e57e09888792750c403d7729b7e4b193005178b3736b325ea0"}, {file = "setproctitle-1.2.3.tar.gz", hash = "sha256:ecf28b1c07a799d76f4326e508157b71aeda07b84b90368ea451c0710dbd32c0"}, ] +setuptools = [ + {file = "setuptools-59.5.0-py3-none-any.whl", hash = "sha256:6d10741ff20b89cd8c6a536ee9dc90d3002dec0226c78fb98605bfb9ef8a7adf"}, + {file = "setuptools-59.5.0.tar.gz", hash = "sha256:d144f85102f999444d06f9c0e8c737fd0194f10f2f7e5fdb77573f6e2fa4fad0"}, +] setuptools-scm = [ {file = "setuptools_scm-6.4.2-py3-none-any.whl", hash = "sha256:acea13255093849de7ccb11af9e1fb8bde7067783450cee9ef7a93139bddf6d4"}, {file = "setuptools_scm-6.4.2.tar.gz", hash = "sha256:6833ac65c6ed9711a4d5d2266f8024cfa07c533a0e55f4c12f6eff280a5a9e30"}, @@ -4034,8 +4059,8 @@ torch = [ {file = "torch-1.11.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:0e48af66ad755f0f9c5f2664028a414f57c49d6adc37e77e06fe0004da4edb61"}, ] torchinfo = [ - {file = "torchinfo-1.6.6-py3-none-any.whl", hash = "sha256:406965434fd768e1ef5c373be6e73fc0e91b0f54a8818165b46297bfc493a8ad"}, - {file = "torchinfo-1.6.6.tar.gz", hash = "sha256:8eb0a38c7dc2403b8f1eab101eb7cca012fb2eaf33caf248acb67c996836b70e"}, + {file = "torchinfo-1.7.0-py3-none-any.whl", hash = "sha256:23d8771d965e50cdd242327ea4669b20978db0235da46ad7889be1b321598fea"}, + {file = "torchinfo-1.7.0.tar.gz", hash = "sha256:24cf949cc65d3926638e845e0aa949feb65f3025c58a8b8e084969bd42e30c0b"}, ] torchmetrics = [ {file = "torchmetrics-0.4.1-py3-none-any.whl", hash = "sha256:70c83f0fc804a4fe00a9e72dbd2960ff76e39ef62570a19bbdce0c15a1ee0d71"}, @@ -4110,8 +4135,8 @@ tqdm = [ {file = "tqdm-4.64.0.tar.gz", hash = "sha256:40be55d30e200777a307a7585aee69e4eabb46b4ec6a4b4a5f2d9f11e7d5408d"}, ] traitlets = [ - {file = "traitlets-5.2.1.post0-py3-none-any.whl", hash = "sha256:f44b708d33d98b0addb40c29d148a761f44af740603a8fd0e2f8b5b27cf0f087"}, - {file = "traitlets-5.2.1.post0.tar.gz", hash = "sha256:70815ecb20ec619d1af28910ade523383be13754283aef90528eb3d47b77c5db"}, + {file = "traitlets-5.2.2.post1-py3-none-any.whl", hash = "sha256:1530d04badddc6a73d50b7ee34667d4b96914da352109117b4280cb56523a51b"}, + {file = "traitlets-5.2.2.post1.tar.gz", hash = "sha256:74803a1baa59af70f023671d86d5c7a834c931186df26d50d362ee6a1ff021fd"}, ] typed-ast = [ {file = "typed_ast-1.4.3-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:2068531575a125b87a41802130fa7e29f26c09a2833fea68d9a40cf33902eba6"}, @@ -4158,8 +4183,8 @@ urllib3 = [ {file = "urllib3-1.26.9.tar.gz", hash = "sha256:aabaf16477806a5e1dd19aa41f8c2b7950dd3c746362d7e3223dbe6de6ac448e"}, ] wandb = [ - {file = "wandb-0.12.16-py2.py3-none-any.whl", hash = "sha256:ed7782dadfb5bc457998eccd995f88ae564cdf2a36b12024e4a5d9a47b1b84e8"}, - {file = "wandb-0.12.16.tar.gz", hash = "sha256:a738b5eb61081fa96fc2e16ffaf6dbde67b78f973ff45bda61ed93659ca09912"}, + {file = "wandb-0.12.17-py2.py3-none-any.whl", hash = "sha256:40e599ed7a4a633a4e1da77d026ee872fcb60a207aafbd1bf8ec1ab5b8171ccf"}, + {file = "wandb-0.12.17.tar.gz", hash = "sha256:ad2fe5a9cbb44c445cb0cfc6d04804f74dca2999ae98f0c8db93721b522f76f1"}, ] wcwidth = [ {file = "wcwidth-0.2.5-py2.py3-none-any.whl", hash = "sha256:beb4802a9cebb9144e99086eff703a642a13d6a0052920003a230f3294bbe784"}, @@ -4177,6 +4202,10 @@ werkzeug = [ {file = "Werkzeug-2.1.2-py3-none-any.whl", hash = "sha256:72a4b735692dd3135217911cbeaa1be5fa3f62bffb8745c5215420a03dc55255"}, {file = "Werkzeug-2.1.2.tar.gz", hash = "sha256:1ce08e8093ed67d638d63879fd1ba3735817f7a80de3674d293f5984f25fb6e6"}, ] +wheel = [ + {file = "wheel-0.37.1-py2.py3-none-any.whl", hash = "sha256:4bdcd7d840138086126cd09254dc6195fb4fc6f01c050a1d7236f2630db1d22a"}, + {file = "wheel-0.37.1.tar.gz", hash = "sha256:e9a504e793efbca1b8e0e9cb979a249cf4a0a7b5b8c9e8b65a5e39d49529c1c4"}, +] widgetsnbextension = [ {file = "widgetsnbextension-3.6.0-py2.py3-none-any.whl", hash = "sha256:4fd321cad39fdcf8a8e248a657202d42917ada8e8ed5dd3f60f073e0d54ceabd"}, {file = "widgetsnbextension-3.6.0.tar.gz", hash = "sha256:e84a7a9fcb9baf3d57106e184a7389a8f8eb935bf741a5eb9d60aa18cc029a80"}, diff --git a/pyproject.toml b/pyproject.toml index 2e7cf28..2f08ace 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,10 +16,10 @@ omegaconf = "^2.1.0" einops = "^0.3.0" pytorch-lightning = "^1.6.3" hydra-core = "^1.1.1" -attr = "^0.3.1" smart-open = "^5.2.1" torch = "^1.11.0" torchvision = "^0.12.0" +attrs = "^21.4.0" [tool.poetry.dev-dependencies] pytest = "^5.4.2" diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py index 15c286a..a0c8416 100644 --- a/text_recognizer/data/base_data_module.py +++ b/text_recognizer/data/base_data_module.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import Callable, Dict, Optional, Tuple, Type, TypeVar -import attr +from attrs import define, field from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader @@ -20,29 +20,29 @@ def load_and_print_info(data_module_class: type) -> None: print(dataset) -@attr.s(repr=False) +@define(repr=False) class BaseDataModule(LightningDataModule): """Base PyTorch Lightning DataModule.""" - def __attrs_pre_init__(self) -> None: + def __attrs_post_init__(self) -> None: """Pre init constructor.""" super().__init__() - mapping: Type[AbstractMapping] = attr.ib() - transform: Optional[Callable] = attr.ib(default=None) - test_transform: Optional[Callable] = attr.ib(default=None) - target_transform: Optional[Callable] = attr.ib(default=None) - train_fraction: float = attr.ib(default=0.8) - batch_size: int = attr.ib(default=16) - num_workers: int = attr.ib(default=0) - pin_memory: bool = attr.ib(default=True) + mapping: Type[AbstractMapping] = field() + transform: Optional[Callable] = field(default=None) + test_transform: Optional[Callable] = field(default=None) + target_transform: Optional[Callable] = field(default=None) + train_fraction: float = field(default=0.8) + batch_size: int = field(default=16) + num_workers: int = field(default=0) + pin_memory: bool = field(default=True) # Placeholders - data_train: BaseDataset = attr.ib(init=False, default=None) - data_val: BaseDataset = attr.ib(init=False, default=None) - data_test: BaseDataset = attr.ib(init=False, default=None) - dims: Tuple[int, ...] = attr.ib(init=False, default=None) - output_dims: Tuple[int, ...] = attr.ib(init=False, default=None) + data_train: BaseDataset = field(init=False, default=None) + data_val: BaseDataset = field(init=False, default=None) + data_test: BaseDataset = field(init=False, default=None) + dims: Tuple[int, ...] = field(init=False, default=None) + output_dims: Tuple[int, ...] = field(init=False, default=None) @classmethod def data_dirname(cls: T) -> Path: diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py index b9567c7..c57cbcc 100644 --- a/text_recognizer/data/base_dataset.py +++ b/text_recognizer/data/base_dataset.py @@ -1,7 +1,7 @@ """Base PyTorch Dataset class.""" from typing import Callable, Dict, Optional, Sequence, Tuple, Union -import attr +from attrs import define, field import torch from torch import Tensor from torch.utils.data import Dataset @@ -9,7 +9,7 @@ from torch.utils.data import Dataset from text_recognizer.data.transforms.load_transform import load_transform_from_file -@attr.s +@define class BaseDataset(Dataset): r"""Base Dataset class that processes data and targets through optional transfroms. @@ -21,10 +21,10 @@ class BaseDataset(Dataset): target transforms. """ - data: Union[Sequence, Tensor] = attr.ib() - targets: Union[Sequence, Tensor] = attr.ib() - transform: Union[Optional[Callable], str] = attr.ib(default=None) - target_transform: Union[Optional[Callable], str] = attr.ib(default=None) + data: Union[Sequence, Tensor] = field() + targets: Union[Sequence, Tensor] = field() + transform: Union[Optional[Callable], str] = field(default=None) + target_transform: Union[Optional[Callable], str] = field(default=None) def __attrs_pre_init__(self) -> None: """Pre init constructor.""" diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py index dc8d31a..94882bf 100644 --- a/text_recognizer/data/emnist.py +++ b/text_recognizer/data/emnist.py @@ -6,7 +6,7 @@ import shutil from typing import Dict, List, Optional, Sequence, Set, Tuple import zipfile -import attr +from attrs import define import h5py from loguru import logger as log import numpy as np @@ -35,7 +35,7 @@ ESSENTIALS_FILENAME = ( ) -@attr.s(auto_attribs=True) +@define(auto_attribs=True) class EMNIST(BaseDataModule): """Lightning DataModule class for loading EMNIST dataset. diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py index c267286..43d55b9 100644 --- a/text_recognizer/data/emnist_lines.py +++ b/text_recognizer/data/emnist_lines.py @@ -3,7 +3,7 @@ from collections import defaultdict from pathlib import Path from typing import DefaultDict, List, Tuple -import attr +from attrs import define, field import h5py from loguru import logger as log import numpy as np @@ -33,17 +33,17 @@ IMAGE_X_PADDING = 28 MAX_OUTPUT_LENGTH = 89 # Same as IAMLines -@attr.s(auto_attribs=True, repr=False) +@define(auto_attribs=True, repr=False) class EMNISTLines(BaseDataModule): """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST.""" - max_length: int = attr.ib(default=128) - min_overlap: float = attr.ib(default=0.0) - max_overlap: float = attr.ib(default=0.33) - num_train: int = attr.ib(default=10_000) - num_val: int = attr.ib(default=2_000) - num_test: int = attr.ib(default=2_000) - emnist: EMNIST = attr.ib(init=False, default=None) + max_length: int = field(default=128) + min_overlap: float = field(default=0.0) + max_overlap: float = field(default=0.33) + num_train: int = field(default=10_000) + num_val: int = field(default=2_000) + num_test: int = field(default=2_000) + emnist: EMNIST = field(init=False, default=None) def __attrs_post_init__(self) -> None: """Post init constructor.""" diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py index 766f3e0..8166863 100644 --- a/text_recognizer/data/iam.py +++ b/text_recognizer/data/iam.py @@ -9,7 +9,7 @@ from typing import Any, Dict, List import xml.etree.ElementTree as ElementTree import zipfile -import attr +from attrs import define, field from boltons.cacheutils import cachedproperty from loguru import logger as log import toml @@ -27,7 +27,7 @@ DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be. LINE_REGION_PADDING = 16 # Add this many pixels around the exact coordinates. -@attr.s(auto_attribs=True) +@define(auto_attribs=True) class IAM(BaseDataModule): r"""The IAM Lines dataset. @@ -44,7 +44,7 @@ class IAM(BaseDataModule): contributed to one set only. """ - metadata: Dict = attr.ib(init=False, default=toml.load(METADATA_FILENAME)) + metadata: Dict = field(init=False, default=toml.load(METADATA_FILENAME)) def prepare_data(self) -> None: """Prepares the IAM dataset.""" diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py index 22d00f1..52c10c3 100644 --- a/text_recognizer/data/iam_extended_paragraphs.py +++ b/text_recognizer/data/iam_extended_paragraphs.py @@ -1,5 +1,5 @@ """IAM original and sythetic dataset class.""" -import attr +from attrs import define, field from torch.utils.data import ConcatDataset from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info @@ -8,7 +8,7 @@ from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs from text_recognizer.data.transforms.load_transform import load_transform_from_file -@attr.s(auto_attribs=True, repr=False) +@define(auto_attribs=True, repr=False) class IAMExtendedParagraphs(BaseDataModule): """A dataset with synthetic and real handwritten paragraph.""" diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py index a79c202..34cf605 100644 --- a/text_recognizer/data/iam_lines.py +++ b/text_recognizer/data/iam_lines.py @@ -7,7 +7,7 @@ import json from pathlib import Path from typing import List, Sequence, Tuple -import attr +from attrs import define, field from loguru import logger as log import numpy as np from PIL import Image, ImageFile, ImageOps @@ -35,14 +35,14 @@ MAX_LABEL_LENGTH = 89 MAX_WORD_PIECE_LENGTH = 72 -@attr.s(auto_attribs=True, repr=False) +@define(auto_attribs=True, repr=False) class IAMLines(BaseDataModule): """IAM handwritten lines dataset.""" - dims: Tuple[int, int, int] = attr.ib( + dims: Tuple[int, int, int] = field( init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH) ) - output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1)) + output_dims: Tuple[int, int] = field(init=False, default=(MAX_LABEL_LENGTH, 1)) def prepare_data(self) -> None: """Creates the IAM lines dataset if not existing.""" diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index 033b93e..b605bbc 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -3,7 +3,7 @@ import json from pathlib import Path from typing import Dict, List, Optional, Sequence, Tuple -import attr +from attrs import define, field from loguru import logger as log import numpy as np from PIL import Image, ImageOps @@ -33,15 +33,15 @@ MAX_LABEL_LENGTH = 682 MAX_WORD_PIECE_LENGTH = 451 -@attr.s(auto_attribs=True, repr=False) +@define(auto_attribs=True, repr=False) class IAMParagraphs(BaseDataModule): """IAM handwriting database paragraphs.""" # Placeholders - dims: Tuple[int, int, int] = attr.ib( + dims: Tuple[int, int, int] = field( init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH) ) - output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1)) + output_dims: Tuple[int, int] = field(init=False, default=(MAX_LABEL_LENGTH, 1)) def prepare_data(self) -> None: """Create data for training/testing.""" @@ -86,7 +86,10 @@ class IAMParagraphs(BaseDataModule): length=self.output_dims[0], ) return BaseDataset( - data, targets, transform=transform, target_transform=target_transform, + data, + targets, + transform=transform, + target_transform=target_transform, ) log.info(f"Loading IAM paragraph regions and lines for {stage}...") diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py index ea59098..7143951 100644 --- a/text_recognizer/data/iam_synthetic_paragraphs.py +++ b/text_recognizer/data/iam_synthetic_paragraphs.py @@ -2,7 +2,7 @@ import random from typing import Any, List, Sequence, Tuple -import attr +from attrs import define from loguru import logger as log import numpy as np from PIL import Image @@ -34,7 +34,7 @@ PROCESSED_DATA_DIRNAME = ( ) -@attr.s(auto_attribs=True, repr=False) +@define(auto_attribs=True, repr=False) class IAMSyntheticParagraphs(IAMParagraphs): """IAM Handwriting database of synthetic paragraphs.""" diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 821cb69..bf3bc08 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -1,7 +1,7 @@ """Base PyTorch Lightning model.""" from typing import Any, Dict, List, Optional, Tuple, Type -import attr +from attrs import define, field import hydra from loguru import logger as log from omegaconf import DictConfig @@ -14,7 +14,7 @@ import torchmetrics from text_recognizer.data.mappings.base import AbstractMapping -@attr.s(eq=False) +@define(eq=False) class BaseLitModel(LightningModule): """Abstract PyTorch Lightning class.""" @@ -22,22 +22,18 @@ class BaseLitModel(LightningModule): """Pre init constructor.""" super().__init__() - network: Type[nn.Module] = attr.ib() - loss_fn: Type[nn.Module] = attr.ib() - optimizer_configs: DictConfig = attr.ib() - lr_scheduler_configs: Optional[DictConfig] = attr.ib() - mapping: Type[AbstractMapping] = attr.ib() + network: Type[nn.Module] = field() + loss_fn: Type[nn.Module] = field() + optimizer_configs: DictConfig = field() + lr_scheduler_configs: Optional[DictConfig] = field() + mapping: Type[AbstractMapping] = field() # Placeholders - train_acc: torchmetrics.Accuracy = attr.ib( - init=False, default=torchmetrics.Accuracy() - ) - val_acc: torchmetrics.Accuracy = attr.ib( - init=False, default=torchmetrics.Accuracy() - ) - test_acc: torchmetrics.Accuracy = attr.ib( + train_acc: torchmetrics.Accuracy = field( init=False, default=torchmetrics.Accuracy() ) + val_acc: torchmetrics.Accuracy = field(init=False, default=torchmetrics.Accuracy()) + test_acc: torchmetrics.Accuracy = field(init=False, default=torchmetrics.Accuracy()) def optimizer_zero_grad( self, diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py index f83c9e4..e59a830 100644 --- a/text_recognizer/models/metrics.py +++ b/text_recognizer/models/metrics.py @@ -1,20 +1,20 @@ """Character Error Rate (CER).""" from typing import Set -import attr +from attrs import define, field import editdistance import torch from torch import Tensor from torchmetrics import Metric -@attr.s(eq=False) +@define(eq=False) class CharacterErrorRate(Metric): """Character error rate metric, computed using Levenshtein distance.""" - ignore_indices: Set[Tensor] = attr.ib(converter=set) - error: Tensor = attr.ib(init=False) - total: Tensor = attr.ib(init=False) + ignore_indices: Set[Tensor] = field(converter=set) + error: Tensor = field(init=False) + total: Tensor = field(init=False) def __attrs_post_init__(self) -> None: super().__init__() diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 7272f46..c5120fe 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -1,7 +1,7 @@ """PyTorch Lightning model for base Transformers.""" from typing import Set, Tuple -import attr +from attrs import define, field import torch from torch import Tensor @@ -9,22 +9,22 @@ from text_recognizer.models.base import BaseLitModel from text_recognizer.models.metrics import CharacterErrorRate -@attr.s(auto_attribs=True, eq=False) +@define(auto_attribs=True, eq=False) class TransformerLitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" - max_output_len: int = attr.ib(default=451) - start_token: str = attr.ib(default="<s>") - end_token: str = attr.ib(default="<e>") - pad_token: str = attr.ib(default="<p>") + max_output_len: int = field(default=451) + start_token: str = field(default="<s>") + end_token: str = field(default="<e>") + pad_token: str = field(default="<p>") - start_index: int = attr.ib(init=False) - end_index: int = attr.ib(init=False) - pad_index: int = attr.ib(init=False) + start_index: int = field(init=False) + end_index: int = field(init=False) + pad_index: int = field(init=False) - ignore_indices: Set[Tensor] = attr.ib(init=False) - val_cer: CharacterErrorRate = attr.ib(init=False) - test_cer: CharacterErrorRate = attr.ib(init=False) + ignore_indices: Set[Tensor] = field(init=False) + val_cer: CharacterErrorRate = field(init=False) + test_cer: CharacterErrorRate = field(init=False) def __attrs_post_init__(self) -> None: """Post init configuration.""" diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py index 4c9ed75..cf64bcf 100644 --- a/text_recognizer/networks/efficientnet/efficientnet.py +++ b/text_recognizer/networks/efficientnet/efficientnet.py @@ -1,7 +1,7 @@ """Efficientnet backbone.""" from typing import Tuple -import attr +from attrs import define, field from torch import nn, Tensor from text_recognizer.networks.efficientnet.mbconv import MBConvBlock @@ -12,7 +12,7 @@ from text_recognizer.networks.efficientnet.utils import ( ) -@attr.s(eq=False) +@define(eq=False) class EfficientNet(nn.Module): """Efficientnet without classification head.""" @@ -33,28 +33,28 @@ class EfficientNet(nn.Module): "l2": (4.3, 5.3, 0.5), } - arch: str = attr.ib() - params: Tuple[float, float, float] = attr.ib(default=None, init=False) - stochastic_dropout_rate: float = attr.ib(default=0.2) - bn_momentum: float = attr.ib(default=0.99) - bn_eps: float = attr.ib(default=1.0e-3) - depth: int = attr.ib(default=7) - out_channels: int = attr.ib(default=None, init=False) - _conv_stem: nn.Sequential = attr.ib(default=None, init=False) - _blocks: nn.ModuleList = attr.ib(default=None, init=False) - _conv_head: nn.Sequential = attr.ib(default=None, init=False) + arch: str = field() + params: Tuple[float, float, float] = field(default=None, init=False) + stochastic_dropout_rate: float = field(default=0.2) + bn_momentum: float = field(default=0.99) + bn_eps: float = field(default=1.0e-3) + depth: int = field(default=7) + out_channels: int = field(default=None, init=False) + _conv_stem: nn.Sequential = field(default=None, init=False) + _blocks: nn.ModuleList = field(default=None, init=False) + _conv_head: nn.Sequential = field(default=None, init=False) def __attrs_post_init__(self) -> None: """Post init configuration.""" self._build() @depth.validator - def _check_depth(self, attribute: attr._make.Attribute, value: str) -> None: + def _check_depth(self, attribute, value: str) -> None: if not 5 <= value <= 7: raise ValueError(f"Depth has to be between 5 and 7, was: {value}") @arch.validator - def _check_arch(self, attribute: attr._make.Attribute, value: str) -> None: + def _check_arch(self, attribute, value: str) -> None: """Validates the efficientnet architecure.""" if value not in self.archs: raise ValueError(f"{value} not a valid architecure.") @@ -88,7 +88,9 @@ class EfficientNet(nn.Module): for _ in range(num_repeats): self._blocks.append( MBConvBlock( - **args, bn_momentum=self.bn_momentum, bn_eps=self.bn_eps, + **args, + bn_momentum=self.bn_momentum, + bn_eps=self.bn_eps, ) ) args.in_channels = args.out_channels diff --git a/text_recognizer/networks/efficientnet/mbconv.py b/text_recognizer/networks/efficientnet/mbconv.py index beb7d57..98e9353 100644 --- a/text_recognizer/networks/efficientnet/mbconv.py +++ b/text_recognizer/networks/efficientnet/mbconv.py @@ -1,7 +1,7 @@ """Mobile inverted residual block.""" from typing import Optional, Tuple, Union -import attr +from attrs import define, field import torch from torch import nn, Tensor import torch.nn.functional as F @@ -14,13 +14,13 @@ def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]: return (stride,) * 2 if isinstance(stride, int) else stride -@attr.s(eq=False) +@define(eq=False) class BaseModule(nn.Module): """Base sub module class.""" - bn_momentum: float = attr.ib() - bn_eps: float = attr.ib() - block: nn.Sequential = attr.ib(init=False) + bn_momentum: float = field() + bn_eps: float = field() + block: nn.Sequential = field(init=False) def __attrs_pre_init__(self) -> None: super().__init__() @@ -36,12 +36,12 @@ class BaseModule(nn.Module): return self.block(x) -@attr.s(auto_attribs=True, eq=False) +@define(auto_attribs=True, eq=False) class InvertedBottleneck(BaseModule): """Inverted bottleneck module.""" - in_channels: int = attr.ib() - out_channels: int = attr.ib() + in_channels: int = field() + out_channels: int = field() def _build(self) -> None: self.block = nn.Sequential( @@ -60,13 +60,13 @@ class InvertedBottleneck(BaseModule): ) -@attr.s(auto_attribs=True, eq=False) +@define(auto_attribs=True, eq=False) class Depthwise(BaseModule): """Depthwise convolution module.""" - channels: int = attr.ib() - kernel_size: int = attr.ib() - stride: int = attr.ib() + channels: int = field() + kernel_size: int = field() + stride: int = field() def _build(self) -> None: self.block = nn.Sequential( @@ -85,13 +85,13 @@ class Depthwise(BaseModule): ) -@attr.s(auto_attribs=True, eq=False) +@define(auto_attribs=True, eq=False) class SqueezeAndExcite(BaseModule): """Sequeeze and excite module.""" - in_channels: int = attr.ib() - channels: int = attr.ib() - se_ratio: float = attr.ib() + in_channels: int = field() + channels: int = field() + se_ratio: float = field() def _build(self) -> None: num_squeezed_channels = max(1, int(self.in_channels * self.se_ratio)) @@ -110,12 +110,12 @@ class SqueezeAndExcite(BaseModule): ) -@attr.s(auto_attribs=True, eq=False) +@define(auto_attribs=True, eq=False) class Pointwise(BaseModule): """Pointwise module.""" - in_channels: int = attr.ib() - out_channels: int = attr.ib() + in_channels: int = field() + out_channels: int = field() def _build(self) -> None: self.block = nn.Sequential( @@ -133,32 +133,35 @@ class Pointwise(BaseModule): ) -@attr.s(eq=False) +@define(eq=False) class MBConvBlock(nn.Module): """Mobile Inverted Residual Bottleneck block.""" def __attrs_pre_init__(self) -> None: super().__init__() - in_channels: int = attr.ib() - out_channels: int = attr.ib() - kernel_size: Tuple[int, int] = attr.ib() - stride: Tuple[int, int] = attr.ib(converter=_convert_stride) - bn_momentum: float = attr.ib() - bn_eps: float = attr.ib() - se_ratio: float = attr.ib() - expand_ratio: int = attr.ib() - pad: Tuple[int, int, int, int] = attr.ib(init=False) - _inverted_bottleneck: Optional[InvertedBottleneck] = attr.ib(init=False) - _depthwise: nn.Sequential = attr.ib(init=False) - _squeeze_excite: nn.Sequential = attr.ib(init=False) - _pointwise: nn.Sequential = attr.ib(init=False) + in_channels: int = field() + out_channels: int = field() + kernel_size: Tuple[int, int] = field() + stride: Tuple[int, int] = field(converter=_convert_stride) + bn_momentum: float = field() + bn_eps: float = field() + se_ratio: float = field() + expand_ratio: int = field() + pad: Tuple[int, int, int, int] = field(init=False) + _inverted_bottleneck: Optional[InvertedBottleneck] = field(init=False) + _depthwise: nn.Sequential = field(init=False) + _squeeze_excite: nn.Sequential = field(init=False) + _pointwise: nn.Sequential = field(init=False) @pad.default def _configure_padding(self) -> Tuple[int, int, int, int]: """Set padding for convolutional layers.""" if self.stride == (2, 2): - return ((self.kernel_size - 1) // 2 - 1, (self.kernel_size - 1) // 2,) * 2 + return ( + (self.kernel_size - 1) // 2 - 1, + (self.kernel_size - 1) // 2, + ) * 2 return ((self.kernel_size - 1) // 2,) * 4 def __attrs_post_init__(self) -> None: diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py index 87792a9..aa15b88 100644 --- a/text_recognizer/networks/transformer/attention.py +++ b/text_recognizer/networks/transformer/attention.py @@ -1,7 +1,7 @@ """Implementes the attention module for the transformer.""" from typing import Optional, Tuple -import attr +from attrs import define, field from einops import rearrange import torch from torch import einsum @@ -15,22 +15,22 @@ from text_recognizer.networks.transformer.embeddings.rotary import ( ) -@attr.s(eq=False) +@define(eq=False) class Attention(nn.Module): """Standard attention.""" def __attrs_pre_init__(self) -> None: super().__init__() - dim: int = attr.ib() - num_heads: int = attr.ib() - causal: bool = attr.ib(default=False) - dim_head: int = attr.ib(default=64) - dropout_rate: float = attr.ib(default=0.0) - rotary_embedding: Optional[RotaryEmbedding] = attr.ib(default=None) - scale: float = attr.ib(init=False) - dropout: nn.Dropout = attr.ib(init=False) - fc: nn.Linear = attr.ib(init=False) + dim: int = field() + num_heads: int = field() + causal: bool = field(default=False) + dim_head: int = field(default=64) + dropout_rate: float = field(default=0.0) + rotary_embedding: Optional[RotaryEmbedding] = field(default=None) + scale: float = field(init=False) + dropout: nn.Dropout = field(init=False) + fc: nn.Linear = field(init=False) def __attrs_post_init__(self) -> None: self.scale = self.dim ** -0.5 @@ -120,7 +120,6 @@ def apply_input_mask( input_mask = q_mask * k_mask energy = energy.masked_fill_(~input_mask, mask_value) - del input_mask return energy @@ -133,5 +132,4 @@ def apply_causal_mask( mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j") mask = F.pad(mask, (j - i, 0), value=False) energy.masked_fill_(mask, mask_value) - del mask return energy |