summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--poetry.lock235
-rw-r--r--pyproject.toml2
-rw-r--r--text_recognizer/data/base_data_module.py32
-rw-r--r--text_recognizer/data/base_dataset.py12
-rw-r--r--text_recognizer/data/emnist.py4
-rw-r--r--text_recognizer/data/emnist_lines.py18
-rw-r--r--text_recognizer/data/iam.py6
-rw-r--r--text_recognizer/data/iam_extended_paragraphs.py4
-rw-r--r--text_recognizer/data/iam_lines.py8
-rw-r--r--text_recognizer/data/iam_paragraphs.py13
-rw-r--r--text_recognizer/data/iam_synthetic_paragraphs.py4
-rw-r--r--text_recognizer/models/base.py24
-rw-r--r--text_recognizer/models/metrics.py10
-rw-r--r--text_recognizer/models/transformer.py24
-rw-r--r--text_recognizer/networks/efficientnet/efficientnet.py32
-rw-r--r--text_recognizer/networks/efficientnet/mbconv.py71
-rw-r--r--text_recognizer/networks/transformer/attention.py24
17 files changed, 277 insertions, 246 deletions
diff --git a/poetry.lock b/poetry.lock
index 7bf5385..3b1e6a1 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -2,7 +2,7 @@
name = "absl-py"
version = "1.0.0"
description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py."
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.6"
@@ -13,7 +13,7 @@ six = "*"
name = "aiohttp"
version = "3.8.1"
description = "Async http client/server framework (asyncio)"
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.6"
@@ -33,7 +33,7 @@ speedups = ["aiodns", "brotli", "cchardet"]
name = "aiosignal"
version = "1.2.0"
description = "aiosignal: a list of registered asynchronous callbacks"
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.6"
@@ -44,7 +44,7 @@ frozenlist = ">=1.1.0"
name = "antlr4-python3-runtime"
version = "4.9.3"
description = "ANTLR 4.9.3 runtime for Python 3.7"
-category = "main"
+category = "dev"
optional = false
python-versions = "*"
@@ -130,7 +130,7 @@ test = ["astroid", "pytest"]
name = "async-timeout"
version = "4.0.2"
description = "Timeout context manager for asyncio programs"
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.6"
@@ -143,14 +143,6 @@ optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
[[package]]
-name = "attr"
-version = "0.3.1"
-description = "Simple decorator to set attributes of target function or class in a DRY way."
-category = "main"
-optional = false
-python-versions = "*"
-
-[[package]]
name = "attrs"
version = "21.4.0"
description = "Classes Without Boilerplate"
@@ -159,9 +151,9 @@ optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
[package.extras]
-dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "furo", "sphinx", "sphinx-notfound-page", "pre-commit", "cloudpickle"]
-docs = ["furo", "sphinx", "zope.interface", "sphinx-notfound-page"]
-tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "cloudpickle"]
+dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope-interface", "furo", "sphinx", "sphinx-notfound-page", "pre-commit", "cloudpickle"]
+docs = ["furo", "sphinx", "zope-interface", "sphinx-notfound-page"]
+tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope-interface", "cloudpickle"]
tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "cloudpickle"]
[[package]]
@@ -250,7 +242,6 @@ six = ">=1.9.0"
webencodings = "*"
[package.extras]
-css = ["tinycss2 (>=1.1.0)"]
dev = ["pip-tools (==6.5.1)", "pytest (==7.1.1)", "flake8 (==4.0.1)", "tox (==3.24.5)", "sphinx (==4.3.2)", "twine (==4.0.0)", "wheel (==0.37.1)", "hashin (==0.17.0)", "black (==22.3.0)", "mypy (==0.942)"]
[[package]]
@@ -263,9 +254,9 @@ python-versions = "*"
[[package]]
name = "cachetools"
-version = "5.1.0"
+version = "5.2.0"
description = "Extensible memoizing collections and decorators"
-category = "main"
+category = "dev"
optional = false
python-versions = "~=3.7"
@@ -273,7 +264,7 @@ python-versions = "~=3.7"
name = "certifi"
version = "2022.5.18.1"
description = "Python package for providing Mozilla's CA Bundle."
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.6"
@@ -292,7 +283,7 @@ pycparser = "*"
name = "charset-normalizer"
version = "2.0.12"
description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.5.0"
@@ -311,7 +302,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
name = "colorama"
version = "0.4.4"
description = "Cross-platform colored terminal text."
-category = "main"
+category = "dev"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
@@ -408,7 +399,7 @@ python-versions = "*"
name = "einops"
version = "0.3.2"
description = "A new flavour of deep learning operations"
-category = "main"
+category = "dev"
optional = false
python-versions = "*"
@@ -528,6 +519,7 @@ python-versions = "*"
[package.dependencies]
pycodestyle = "*"
+setuptools = "*"
[[package]]
name = "flake8-polyfill"
@@ -566,7 +558,7 @@ woff = ["zopfli (>=0.1.4)", "brotlicffi (>=0.8.0)", "brotli (>=1.0.1)"]
name = "frozenlist"
version = "1.3.0"
description = "A list-like structure which implements collections.abc.MutableSequence"
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.7"
@@ -574,7 +566,7 @@ python-versions = ">=3.7"
name = "fsspec"
version = "2022.5.0"
description = "File-system specification"
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.7"
@@ -631,7 +623,7 @@ gitdb = ">=4.0.1,<5"
name = "google-auth"
version = "2.6.6"
description = "Google Authentication Library"
-category = "main"
+category = "dev"
optional = false
python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*"
@@ -650,7 +642,7 @@ reauth = ["pyu2f (>=0.1.5)"]
name = "google-auth-oauthlib"
version = "0.4.6"
description = "Google Authentication Library"
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.6"
@@ -665,7 +657,7 @@ tool = ["click (>=6.0.0)"]
name = "grpcio"
version = "1.46.3"
description = "HTTP/2-based RPC framework"
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.6"
@@ -690,7 +682,7 @@ numpy = ">=1.14.5"
name = "hydra-core"
version = "1.2.0"
description = "A framework for elegantly configuring complex applications"
-category = "main"
+category = "dev"
optional = false
python-versions = "*"
@@ -703,7 +695,7 @@ packaging = "*"
name = "idna"
version = "3.3"
description = "Internationalized Domain Names in Applications (IDNA)"
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.5"
@@ -711,7 +703,7 @@ python-versions = ">=3.5"
name = "importlib-metadata"
version = "4.11.4"
description = "Read metadata from Python packages"
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.7"
@@ -721,7 +713,7 @@ zipp = ">=0.5"
[package.extras]
docs = ["sphinx", "jaraco.packaging (>=9)", "rst.linker (>=1.9)"]
perf = ["ipython"]
-testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "packaging", "pyfakefs", "flufl.flake8", "pytest-perf (>=0.9.2)", "pytest-black (>=0.3.7)", "pytest-mypy (>=0.9.1)", "importlib-resources (>=1.3)"]
+testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "packaging", "pyfakefs", "flufl-flake8", "pytest-perf (>=0.9.2)", "pytest-black (>=0.3.7)", "pytest-mypy (>=0.9.1)", "importlib-resources (>=1.3)"]
[[package]]
name = "ipykernel"
@@ -748,7 +740,7 @@ test = ["pytest (>=6.0)", "pytest-cov", "flaky", "ipyparallel", "pre-commit", "p
[[package]]
name = "ipython"
-version = "8.3.0"
+version = "8.4.0"
description = "IPython: Productive Interactive Computing"
category = "dev"
optional = false
@@ -765,6 +757,7 @@ pexpect = {version = ">4.3", markers = "sys_platform != \"win32\""}
pickleshare = "*"
prompt-toolkit = ">=2.0.0,<3.0.0 || >3.0.0,<3.0.1 || >3.0.1,<3.1.0"
pygments = ">=2.4.0"
+setuptools = ">=18.5"
stack-data = "*"
traitlets = ">=5"
@@ -1044,7 +1037,7 @@ python-versions = ">=3.7"
name = "loguru"
version = "0.6.0"
description = "Python logging made (stupidly) simple"
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.5"
@@ -1059,7 +1052,7 @@ dev = ["colorama (>=0.3.4)", "docutils (==0.16)", "flake8 (>=3.7.7)", "tox (>=3.
name = "markdown"
version = "3.3.7"
description = "Python implementation of Markdown."
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.6"
@@ -1135,7 +1128,7 @@ python-versions = ">=3.5"
name = "multidict"
version = "6.0.2"
description = "multidict implementation"
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.7"
@@ -1181,7 +1174,7 @@ test = ["pytest", "pytest-tornasync", "pytest-console-scripts"]
[[package]]
name = "nbclient"
-version = "0.6.3"
+version = "0.6.4"
description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor."
category = "dev"
optional = false
@@ -1191,7 +1184,7 @@ python-versions = ">=3.7.0"
jupyter-client = ">=6.1.5"
nbformat = ">=5.0"
nest-asyncio = "*"
-traitlets = ">=5.0.0"
+traitlets = ">=5.2.2"
[package.extras]
sphinx = ["autodoc-traits", "mock", "moto", "myst-parser", "Sphinx (>=1.7)", "sphinx-book-theme"]
@@ -1325,7 +1318,7 @@ test = ["pytest", "pytest-tornasync", "pytest-console-scripts"]
name = "numpy"
version = "1.22.4"
description = "NumPy is the fundamental package for array computing with Python."
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.8"
@@ -1333,7 +1326,7 @@ python-versions = ">=3.8"
name = "oauthlib"
version = "3.2.0"
description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic"
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.6"
@@ -1344,9 +1337,9 @@ signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"]
[[package]]
name = "omegaconf"
-version = "2.2.1"
+version = "2.2.2"
description = "A flexible configuration library"
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.6"
@@ -1374,7 +1367,7 @@ numpy = [
name = "packaging"
version = "21.3"
description = "Core utilities for Python packages"
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.6"
@@ -1448,7 +1441,7 @@ python-versions = "*"
name = "pillow"
version = "9.1.1"
description = "Python Imaging Library (Fork)"
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.7"
@@ -1507,7 +1500,7 @@ wcwidth = "*"
name = "protobuf"
version = "3.20.1"
description = "Protocol Buffers"
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.7"
@@ -1553,7 +1546,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
name = "pyasn1"
version = "0.4.8"
description = "ASN.1 types and codecs"
-category = "main"
+category = "dev"
optional = false
python-versions = "*"
@@ -1561,7 +1554,7 @@ python-versions = "*"
name = "pyasn1-modules"
version = "0.2.8"
description = "A collection of ASN.1-based protocols modules."
-category = "main"
+category = "dev"
optional = false
python-versions = "*"
@@ -1588,7 +1581,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
name = "pydeprecate"
version = "0.3.2"
description = "Deprecation tooling"
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.6"
@@ -1626,7 +1619,7 @@ python-versions = ">=3.6"
name = "pyparsing"
version = "3.0.9"
description = "pyparsing module - Classes and methods to define and execute parsing grammars"
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.6.8"
@@ -1706,9 +1699,9 @@ six = ">=1.5"
[[package]]
name = "pytorch-lightning"
-version = "1.6.3"
+version = "1.6.4"
description = "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate."
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.7"
@@ -1716,7 +1709,8 @@ python-versions = ">=3.7"
fsspec = {version = ">=2021.05.0,<2021.06.0 || >2021.06.0", extras = ["http"]}
numpy = ">=1.17.2"
packaging = ">=17.0"
-pyDeprecate = ">=0.3.1,<0.4.0"
+protobuf = "<=3.20.1"
+pyDeprecate = ">=0.3.1"
PyYAML = ">=5.4"
tensorboard = ">=2.2.0"
torch = ">=1.8"
@@ -1725,14 +1719,17 @@ tqdm = ">=4.57.0"
typing-extensions = ">=4.0.0"
[package.extras]
-all = ["matplotlib (>3.1)", "horovod (>=0.21.2,!=0.24.0)", "torchtext (>=0.9)", "omegaconf (>=2.0.5)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.7.1)", "gcsfs (>=2021.5.0)", "rich (>=10.2.2,<10.15.0 || >=10.16.0)", "neptune-client (>=0.10.0)", "comet-ml (>=3.1.12)", "mlflow (>=1.0.0)", "test-tube (>=0.7.5)", "wandb (>=0.8.21)", "coverage (>5.2.0,<6.3)", "codecov (>=2.1)", "pytest (>=6.0)", "pytest-rerunfailures (>=10.2)", "twine (==3.2)", "mypy (>=0.920)", "flake8 (>=3.9.2)", "pre-commit (>=1.0)", "pytest-forked", "sklearn", "jsonargparse", "cloudpickle (>=1.3)", "scikit-learn (>0.22.1)", "onnxruntime", "pandas", "torchvision (>=0.9)", "gym[classic_control] (>=0.17.0)", "ipython"]
-cpu = ["matplotlib (>3.1)", "torchtext (>=0.9)", "omegaconf (>=2.0.5)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.7.1)", "gcsfs (>=2021.5.0)", "rich (>=10.2.2,<10.15.0 || >=10.16.0)", "neptune-client (>=0.10.0)", "comet-ml (>=3.1.12)", "mlflow (>=1.0.0)", "test-tube (>=0.7.5)", "wandb (>=0.8.21)", "coverage (>5.2.0,<6.3)", "codecov (>=2.1)", "pytest (>=6.0)", "pytest-rerunfailures (>=10.2)", "twine (==3.2)", "mypy (>=0.920)", "flake8 (>=3.9.2)", "pre-commit (>=1.0)", "pytest-forked", "sklearn", "jsonargparse", "cloudpickle (>=1.3)", "scikit-learn (>0.22.1)", "onnxruntime", "pandas", "torchvision (>=0.9)", "gym[classic_control] (>=0.17.0)", "ipython"]
-cpu-extra = ["matplotlib (>3.1)", "torchtext (>=0.9)", "omegaconf (>=2.0.5)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.7.1)", "gcsfs (>=2021.5.0)", "rich (>=10.2.2,<10.15.0 || >=10.16.0)"]
-dev = ["matplotlib (>3.1)", "horovod (>=0.21.2,!=0.24.0)", "torchtext (>=0.9)", "omegaconf (>=2.0.5)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.7.1)", "gcsfs (>=2021.5.0)", "rich (>=10.2.2,<10.15.0 || >=10.16.0)", "neptune-client (>=0.10.0)", "comet-ml (>=3.1.12)", "mlflow (>=1.0.0)", "test-tube (>=0.7.5)", "wandb (>=0.8.21)", "coverage (>5.2.0,<6.3)", "codecov (>=2.1)", "pytest (>=6.0)", "pytest-rerunfailures (>=10.2)", "twine (==3.2)", "mypy (>=0.920)", "flake8 (>=3.9.2)", "pre-commit (>=1.0)", "pytest-forked", "sklearn", "jsonargparse", "cloudpickle (>=1.3)", "scikit-learn (>0.22.1)", "onnxruntime", "pandas"]
+all = ["matplotlib (>3.1)", "torchtext (>=0.9)", "omegaconf (>=2.0.5)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.7.1)", "gcsfs (>=2021.5.0)", "rich (>=10.2.2,!=10.15.*)", "neptune-client (>=0.10.0)", "comet-ml (>=3.1.12)", "mlflow (>=1.0.0)", "test-tube (>=0.7.5)", "wandb (>=0.8.21)", "coverage (>=6.4)", "codecov (>=2.1)", "pytest (>=6.0)", "pytest-rerunfailures (>=10.2)", "mypy (>=0.920)", "flake8 (>=3.9.2)", "pre-commit (>=1.0)", "pytest-forked", "cloudpickle (>=1.3)", "scikit-learn (>0.22.1)", "onnxruntime", "pandas", "torchvision (>=0.9)", "gym[classic_control] (>=0.17.0)", "ipython", "fairscale (>=0.4.5)", "deepspeed", "horovod (>=0.21.2,!=0.24.0)", "hivemind (>=1.0.1)"]
+deepspeed = ["deepspeed"]
+dev = ["matplotlib (>3.1)", "torchtext (>=0.9)", "omegaconf (>=2.0.5)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.7.1)", "gcsfs (>=2021.5.0)", "rich (>=10.2.2,!=10.15.*)", "neptune-client (>=0.10.0)", "comet-ml (>=3.1.12)", "mlflow (>=1.0.0)", "test-tube (>=0.7.5)", "wandb (>=0.8.21)", "coverage (>=6.4)", "codecov (>=2.1)", "pytest (>=6.0)", "pytest-rerunfailures (>=10.2)", "mypy (>=0.920)", "flake8 (>=3.9.2)", "pre-commit (>=1.0)", "pytest-forked", "cloudpickle (>=1.3)", "scikit-learn (>0.22.1)", "onnxruntime", "pandas"]
examples = ["torchvision (>=0.9)", "gym[classic_control] (>=0.17.0)", "ipython"]
-extra = ["matplotlib (>3.1)", "horovod (>=0.21.2,!=0.24.0)", "torchtext (>=0.9)", "omegaconf (>=2.0.5)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.7.1)", "gcsfs (>=2021.5.0)", "rich (>=10.2.2,<10.15.0 || >=10.16.0)"]
+extra = ["matplotlib (>3.1)", "torchtext (>=0.9)", "omegaconf (>=2.0.5)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.7.1)", "gcsfs (>=2021.5.0)", "rich (>=10.2.2,!=10.15.*)"]
+fairscale = ["fairscale (>=0.4.5)"]
+hivemind = ["hivemind (>=1.0.1)"]
+horovod = ["horovod (>=0.21.2,!=0.24.0)"]
loggers = ["neptune-client (>=0.10.0)", "comet-ml (>=3.1.12)", "mlflow (>=1.0.0)", "test-tube (>=0.7.5)", "wandb (>=0.8.21)"]
-test = ["coverage (>5.2.0,<6.3)", "codecov (>=2.1)", "pytest (>=6.0)", "pytest-rerunfailures (>=10.2)", "twine (==3.2)", "mypy (>=0.920)", "flake8 (>=3.9.2)", "pre-commit (>=1.0)", "pytest-forked", "sklearn", "jsonargparse", "cloudpickle (>=1.3)", "scikit-learn (>0.22.1)", "onnxruntime", "pandas"]
+strategies = ["fairscale (>=0.4.5)", "deepspeed", "horovod (>=0.21.2,!=0.24.0)", "hivemind (>=1.0.1)"]
+test = ["coverage (>=6.4)", "codecov (>=2.1)", "pytest (>=6.0)", "pytest-rerunfailures (>=10.2)", "mypy (>=0.920)", "flake8 (>=3.9.2)", "pre-commit (>=1.0)", "pytest-forked", "cloudpickle (>=1.3)", "scikit-learn (>0.22.1)", "onnxruntime", "pandas"]
[[package]]
name = "pytz"
@@ -1762,7 +1759,7 @@ python-versions = ">=3.7"
name = "pyyaml"
version = "6.0"
description = "YAML parser and emitter for Python"
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.6"
@@ -1826,7 +1823,7 @@ python-versions = ">=3.6"
name = "requests"
version = "2.27.1"
description = "Python HTTP for Humans."
-category = "main"
+category = "dev"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*"
@@ -1844,7 +1841,7 @@ use_chardet_on_py3 = ["chardet (>=3.0.2,<5)"]
name = "requests-oauthlib"
version = "1.3.1"
description = "OAuthlib authentication support for Requests."
-category = "main"
+category = "dev"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
@@ -1859,7 +1856,7 @@ rsa = ["oauthlib[signedtoken] (>=3.0.0)"]
name = "rsa"
version = "4.8"
description = "Pure-Python RSA implementation"
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.6,<4"
@@ -1879,6 +1876,7 @@ Click = ">=6.0"
dparse = ">=0.5.1"
packaging = "*"
requests = "*"
+setuptools = "*"
[[package]]
name = "scipy"
@@ -1946,6 +1944,18 @@ python-versions = ">=3.6"
test = ["pytest"]
[[package]]
+name = "setuptools"
+version = "59.5.0"
+description = "Easily download, build, install, upgrade, and uninstall Python packages"
+category = "dev"
+optional = false
+python-versions = ">=3.6"
+
+[package.extras]
+docs = ["sphinx", "jaraco.packaging (>=8.2)", "rst.linker (>=1.9)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "sphinx-inline-tabs", "sphinxcontrib-towncrier", "furo"]
+testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "mock", "flake8-2020", "virtualenv (>=13.0.0)", "pytest-virtualenv (>=1.2.7)", "wheel", "paver", "pip (>=19.1)", "jaraco.envs (>=2.2)", "pytest-xdist", "sphinx", "jaraco.path (>=3.2.0)", "pytest-black (>=0.3.7)", "pytest-mypy"]
+
+[[package]]
name = "setuptools-scm"
version = "6.4.2"
description = "the blessed package to manage your versions by scm tags"
@@ -1955,6 +1965,7 @@ python-versions = ">=3.6"
[package.dependencies]
packaging = ">=20.0"
+setuptools = "*"
tomli = ">=1.0.0"
[package.extras]
@@ -1973,7 +1984,7 @@ python-versions = ">=3.5"
name = "six"
version = "1.16.0"
description = "Python 2 and 3 compatibility utilities"
-category = "main"
+category = "dev"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
@@ -1981,7 +1992,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
name = "smart-open"
version = "5.2.1"
description = "Utils for streaming large files (S3, HDFS, GCS, Azure Blob Storage, gzip, bz2...)"
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.6,<4.0"
@@ -2057,7 +2068,7 @@ pbr = ">=2.0.0,<2.1.0 || >2.1.0"
name = "tensorboard"
version = "2.9.0"
description = "TensorBoard lets you watch Tensors Flow"
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.6"
@@ -2070,15 +2081,17 @@ markdown = ">=2.6.8"
numpy = ">=1.12.0"
protobuf = ">=3.9.2"
requests = ">=2.21.0,<3"
+setuptools = ">=41.0.0"
tensorboard-data-server = ">=0.6.0,<0.7.0"
tensorboard-plugin-wit = ">=1.6.0"
werkzeug = ">=1.0.1"
+wheel = ">=0.26"
[[package]]
name = "tensorboard-data-server"
version = "0.6.1"
description = "Fast data loading for TensorBoard"
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.6"
@@ -2086,7 +2099,7 @@ python-versions = ">=3.6"
name = "tensorboard-plugin-wit"
version = "1.8.1"
description = "What-If Tool TensorBoard plugin."
-category = "main"
+category = "dev"
optional = false
python-versions = "*"
@@ -2141,7 +2154,7 @@ python-versions = ">=3.7"
name = "torch"
version = "1.11.0"
description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.7.0"
@@ -2150,7 +2163,7 @@ typing-extensions = "*"
[[package]]
name = "torchinfo"
-version = "1.6.6"
+version = "1.7.0"
description = "Model summary in PyTorch, based off of the original torchsummary."
category = "dev"
optional = false
@@ -2160,7 +2173,7 @@ python-versions = ">=3.7"
name = "torchmetrics"
version = "0.4.1"
description = "PyTorch native Metrics"
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.6"
@@ -2176,7 +2189,7 @@ image = ["scipy", "torchvision", "torch-fidelity"]
name = "torchvision"
version = "0.12.0"
description = "image and video datasets and models for torch deep learning"
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.7"
@@ -2202,7 +2215,7 @@ python-versions = ">= 3.5"
name = "tqdm"
version = "4.64.0"
description = "Fast, Extensible Progress Meter"
-category = "main"
+category = "dev"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7"
@@ -2217,7 +2230,7 @@ telegram = ["requests"]
[[package]]
name = "traitlets"
-version = "5.2.1.post0"
+version = "5.2.2.post1"
description = ""
category = "dev"
optional = false
@@ -2250,7 +2263,7 @@ test = ["pytest", "typing-extensions", "mypy"]
name = "typing-extensions"
version = "4.2.0"
description = "Backported and Experimental Type Hints for Python 3.7+"
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.7"
@@ -2258,7 +2271,7 @@ python-versions = ">=3.7"
name = "urllib3"
version = "1.26.9"
description = "HTTP library with thread-safe connection pooling, file post, and more."
-category = "main"
+category = "dev"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4"
@@ -2269,7 +2282,7 @@ socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"]
[[package]]
name = "wandb"
-version = "0.12.16"
+version = "0.12.17"
description = "A CLI and library for interacting with the Weights and Biases API."
category = "dev"
optional = false
@@ -2281,13 +2294,14 @@ docker-pycreds = ">=0.4.0"
GitPython = ">=1.0.0"
pathtools = "*"
promise = ">=2.0,<3"
-protobuf = ">=3.12.0"
+protobuf = ">=3.12.0,<4.0dev"
psutil = ">=5.0.0"
python-dateutil = ">=2.6.1"
PyYAML = "*"
requests = ">=2.0.0,<3"
sentry-sdk = ">=1.0.0"
setproctitle = "*"
+setuptools = "*"
shortuuid = ">=0.5.0"
six = ">=1.13.0"
@@ -2334,7 +2348,7 @@ test = ["websockets"]
name = "werkzeug"
version = "2.1.2"
description = "The comprehensive WSGI web application library."
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.7"
@@ -2342,6 +2356,17 @@ python-versions = ">=3.7"
watchdog = ["watchdog"]
[[package]]
+name = "wheel"
+version = "0.37.1"
+description = "A built-package format for Python"
+category = "dev"
+optional = false
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7"
+
+[package.extras]
+test = ["pytest (>=3.0.0)", "pytest-cov"]
+
+[[package]]
name = "widgetsnbextension"
version = "3.6.0"
description = "IPython HTML widgets for Jupyter"
@@ -2356,7 +2381,7 @@ notebook = ">=4.4.1"
name = "win32-setctime"
version = "1.1.0"
description = "A small Python utility to set file creation time on Windows"
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.5"
@@ -2367,7 +2392,7 @@ dev = ["pytest (>=4.6.2)", "black (>=19.3b0)"]
name = "yarl"
version = "1.7.2"
description = "Yet another URL library"
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.6"
@@ -2379,18 +2404,18 @@ multidict = ">=4.0"
name = "zipp"
version = "3.8.0"
description = "Backport of pathlib-compatible object wrapper for zip files"
-category = "main"
+category = "dev"
optional = false
python-versions = ">=3.7"
[package.extras]
docs = ["sphinx", "jaraco.packaging (>=9)", "rst.linker (>=1.9)"]
-testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "jaraco.itertools", "func-timeout", "pytest-black (>=0.3.7)", "pytest-mypy (>=0.9.1)"]
+testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "jaraco-itertools", "func-timeout", "pytest-black (>=0.3.7)", "pytest-mypy (>=0.9.1)"]
[metadata]
lock-version = "1.1"
python-versions = "^3.9"
-content-hash = "608eb05815709f89d2b622af655b68624f43064413bb4cd6e5d681029bbb76c7"
+content-hash = "b0c8b83118a86644eb9ceb5e33f9944d13016592a849111163a853dca3adb2aa"
[metadata.files]
absl-py = [
@@ -2529,10 +2554,6 @@ atomicwrites = [
{file = "atomicwrites-1.4.0-py2.py3-none-any.whl", hash = "sha256:6d1784dea7c0c8d4a5172b6c620f40b6e4cbfdf96d783691f2e1302a7b88e197"},
{file = "atomicwrites-1.4.0.tar.gz", hash = "sha256:ae70396ad1a434f9c7046fd2dd196fc04b12f9e91ffb859164193be8b6168a7a"},
]
-attr = [
- {file = "attr-0.3.1-py2-none-any.whl", hash = "sha256:0b1aaddb85bd9e9c4bd75092f4440d6616ff40b0df0437f00771871670f7c9fd"},
- {file = "attr-0.3.1.tar.gz", hash = "sha256:9091548058d17f132596e61fa7518e504f76b9a4c61ca7d86e1f96dbf7d4775d"},
-]
attrs = [
{file = "attrs-21.4.0-py2.py3-none-any.whl", hash = "sha256:2d27e3784d7a565d36ab851fe94887c5eccd6a463168875832a1be79c82828b4"},
{file = "attrs-21.4.0.tar.gz", hash = "sha256:626ba8234211db98e869df76230a137c4c40a12d72445c45d5f5b716f076e2fd"},
@@ -2566,8 +2587,8 @@ boltons = [
{file = "boltons-20.2.1.tar.gz", hash = "sha256:dd362291a460cc1e0c2e91cc6a60da3036ced77099b623112e8f833e6734bdc5"},
]
cachetools = [
- {file = "cachetools-5.1.0-py3-none-any.whl", hash = "sha256:4ebbd38701cdfd3603d1f751d851ed248ab4570929f2d8a7ce69e30c420b141c"},
- {file = "cachetools-5.1.0.tar.gz", hash = "sha256:8b3b8fa53f564762e5b221e9896798951e7f915513abf2ba072ce0f07f3f5a98"},
+ {file = "cachetools-5.2.0-py3-none-any.whl", hash = "sha256:f9f17d2aec496a9aa6b76f53e3b614c965223c061982d434d160f930c698a9db"},
+ {file = "cachetools-5.2.0.tar.gz", hash = "sha256:6a94c6402995a99c3970cc7e4884bb60b4a8639938157eeed436098bf9831757"},
]
certifi = [
{file = "certifi-2022.5.18.1-py3-none-any.whl", hash = "sha256:f1d53542ee8cbedbe2118b5686372fb33c297fcd6379b050cca0ef13a597382a"},
@@ -2994,8 +3015,8 @@ ipykernel = [
{file = "ipykernel-6.13.0.tar.gz", hash = "sha256:0e28273e290858393e86e152b104e5506a79c13d25b951ac6eca220051b4be60"},
]
ipython = [
- {file = "ipython-8.3.0-py3-none-any.whl", hash = "sha256:341456643a764c28f670409bbd5d2518f9b82c013441084ff2c2fc999698f83b"},
- {file = "ipython-8.3.0.tar.gz", hash = "sha256:807ae3cf43b84693c9272f70368440a9a7eaa2e7e6882dad943c32fbf7e51402"},
+ {file = "ipython-8.4.0-py3-none-any.whl", hash = "sha256:7ca74052a38fa25fe9bedf52da0be7d3fdd2fb027c3b778ea78dfe8c212937d1"},
+ {file = "ipython-8.4.0.tar.gz", hash = "sha256:f2db3a10254241d9b447232cec8b424847f338d9d36f9a577a6192c332a46abd"},
]
ipython-genutils = [
{file = "ipython_genutils-0.2.0-py2.py3-none-any.whl", hash = "sha256:72dd37233799e619666c9f639a9da83c34013a73e8bbc79a7a6348d93c61fab8"},
@@ -3295,8 +3316,8 @@ nbclassic = [
{file = "nbclassic-0.3.7.tar.gz", hash = "sha256:36dbaa88ffaf5dc05d149deb97504b86ba648f4a80a60b8a58ac94acab2daeb5"},
]
nbclient = [
- {file = "nbclient-0.6.3-py3-none-any.whl", hash = "sha256:2747ac9b385720d8a6c34f2f71e72cbe64aec6cadaadcc064a4df0b0e99c5874"},
- {file = "nbclient-0.6.3.tar.gz", hash = "sha256:b80726fc1fb89a0e8f8be1e77e28d0026b1e8ed90bc143c8a0c7622e4f8cdd9e"},
+ {file = "nbclient-0.6.4-py3-none-any.whl", hash = "sha256:f251bba200a2b401a061dfd700a7a70b5772f664fb49d4a2d3e5536ec0e98c76"},
+ {file = "nbclient-0.6.4.tar.gz", hash = "sha256:cdef7757cead1735d2c70cc66095b072dced8a1e6d1c7639ef90cd3e04a11f2e"},
]
nbconvert = [
{file = "nbconvert-6.5.0-py3-none-any.whl", hash = "sha256:c56dd0b8978a1811a5654f74c727ff16ca87dd5a43abd435a1c49b840fcd8360"},
@@ -3351,8 +3372,8 @@ oauthlib = [
{file = "oauthlib-3.2.0.tar.gz", hash = "sha256:23a8208d75b902797ea29fd31fa80a15ed9dc2c6c16fe73f5d346f83f6fa27a2"},
]
omegaconf = [
- {file = "omegaconf-2.2.1-py3-none-any.whl", hash = "sha256:5ce512b0a8996b5acddc7b30f6bffa337a4b0ada1d96b4270588365e2a69e6d5"},
- {file = "omegaconf-2.2.1.tar.gz", hash = "sha256:6796a5b51a0112410d9940c3a5d2938d5e644377a5517c02db1aef99e03b8af2"},
+ {file = "omegaconf-2.2.2-py3-none-any.whl", hash = "sha256:556917181487fb66fe832d3c7b324f51b2f4c8adc373dd5091be921501b7d420"},
+ {file = "omegaconf-2.2.2.tar.gz", hash = "sha256:65c85b2a84669a570c70f2df00de3cebcd9b47a8587d3c53b1aa5766bb096f77"},
]
opencv-python = [
{file = "opencv-python-4.5.5.64.tar.gz", hash = "sha256:f65de0446a330c3b773cd04ba10345d8ce1b15dcac3f49770204e37602d0b3f7"},
@@ -3619,8 +3640,8 @@ python-dateutil = [
{file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"},
]
pytorch-lightning = [
- {file = "pytorch-lightning-1.6.3.tar.gz", hash = "sha256:beb1f36a6dae91f5fef0959a04af1092dff4f3f4d99c20f0e033f84e615903e3"},
- {file = "pytorch_lightning-1.6.3-py3-none-any.whl", hash = "sha256:5419adaee5bb8057b1dad69d2cbb79f823f54a94bb67cb47fe75cdf8c1bc5616"},
+ {file = "pytorch-lightning-1.6.4.tar.gz", hash = "sha256:5459f2c3e67676ec59e94576d1499e9559d214e7df41eadd135db64b4ccf54b9"},
+ {file = "pytorch_lightning-1.6.4-py3-none-any.whl", hash = "sha256:0f42f93116a3fcb6fd8c9ea45cf7c918e4aa3f848ae21d0e9ac2bf39f2865dd7"},
]
pytz = [
{file = "pytz-2022.1-py2.py3-none-any.whl", hash = "sha256:e68985985296d9a66a881eb3193b0906246245294a881e7c8afe623866ac6a5c"},
@@ -3945,6 +3966,10 @@ setproctitle = [
{file = "setproctitle-1.2.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:97accd117392b1e57e09888792750c403d7729b7e4b193005178b3736b325ea0"},
{file = "setproctitle-1.2.3.tar.gz", hash = "sha256:ecf28b1c07a799d76f4326e508157b71aeda07b84b90368ea451c0710dbd32c0"},
]
+setuptools = [
+ {file = "setuptools-59.5.0-py3-none-any.whl", hash = "sha256:6d10741ff20b89cd8c6a536ee9dc90d3002dec0226c78fb98605bfb9ef8a7adf"},
+ {file = "setuptools-59.5.0.tar.gz", hash = "sha256:d144f85102f999444d06f9c0e8c737fd0194f10f2f7e5fdb77573f6e2fa4fad0"},
+]
setuptools-scm = [
{file = "setuptools_scm-6.4.2-py3-none-any.whl", hash = "sha256:acea13255093849de7ccb11af9e1fb8bde7067783450cee9ef7a93139bddf6d4"},
{file = "setuptools_scm-6.4.2.tar.gz", hash = "sha256:6833ac65c6ed9711a4d5d2266f8024cfa07c533a0e55f4c12f6eff280a5a9e30"},
@@ -4034,8 +4059,8 @@ torch = [
{file = "torch-1.11.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:0e48af66ad755f0f9c5f2664028a414f57c49d6adc37e77e06fe0004da4edb61"},
]
torchinfo = [
- {file = "torchinfo-1.6.6-py3-none-any.whl", hash = "sha256:406965434fd768e1ef5c373be6e73fc0e91b0f54a8818165b46297bfc493a8ad"},
- {file = "torchinfo-1.6.6.tar.gz", hash = "sha256:8eb0a38c7dc2403b8f1eab101eb7cca012fb2eaf33caf248acb67c996836b70e"},
+ {file = "torchinfo-1.7.0-py3-none-any.whl", hash = "sha256:23d8771d965e50cdd242327ea4669b20978db0235da46ad7889be1b321598fea"},
+ {file = "torchinfo-1.7.0.tar.gz", hash = "sha256:24cf949cc65d3926638e845e0aa949feb65f3025c58a8b8e084969bd42e30c0b"},
]
torchmetrics = [
{file = "torchmetrics-0.4.1-py3-none-any.whl", hash = "sha256:70c83f0fc804a4fe00a9e72dbd2960ff76e39ef62570a19bbdce0c15a1ee0d71"},
@@ -4110,8 +4135,8 @@ tqdm = [
{file = "tqdm-4.64.0.tar.gz", hash = "sha256:40be55d30e200777a307a7585aee69e4eabb46b4ec6a4b4a5f2d9f11e7d5408d"},
]
traitlets = [
- {file = "traitlets-5.2.1.post0-py3-none-any.whl", hash = "sha256:f44b708d33d98b0addb40c29d148a761f44af740603a8fd0e2f8b5b27cf0f087"},
- {file = "traitlets-5.2.1.post0.tar.gz", hash = "sha256:70815ecb20ec619d1af28910ade523383be13754283aef90528eb3d47b77c5db"},
+ {file = "traitlets-5.2.2.post1-py3-none-any.whl", hash = "sha256:1530d04badddc6a73d50b7ee34667d4b96914da352109117b4280cb56523a51b"},
+ {file = "traitlets-5.2.2.post1.tar.gz", hash = "sha256:74803a1baa59af70f023671d86d5c7a834c931186df26d50d362ee6a1ff021fd"},
]
typed-ast = [
{file = "typed_ast-1.4.3-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:2068531575a125b87a41802130fa7e29f26c09a2833fea68d9a40cf33902eba6"},
@@ -4158,8 +4183,8 @@ urllib3 = [
{file = "urllib3-1.26.9.tar.gz", hash = "sha256:aabaf16477806a5e1dd19aa41f8c2b7950dd3c746362d7e3223dbe6de6ac448e"},
]
wandb = [
- {file = "wandb-0.12.16-py2.py3-none-any.whl", hash = "sha256:ed7782dadfb5bc457998eccd995f88ae564cdf2a36b12024e4a5d9a47b1b84e8"},
- {file = "wandb-0.12.16.tar.gz", hash = "sha256:a738b5eb61081fa96fc2e16ffaf6dbde67b78f973ff45bda61ed93659ca09912"},
+ {file = "wandb-0.12.17-py2.py3-none-any.whl", hash = "sha256:40e599ed7a4a633a4e1da77d026ee872fcb60a207aafbd1bf8ec1ab5b8171ccf"},
+ {file = "wandb-0.12.17.tar.gz", hash = "sha256:ad2fe5a9cbb44c445cb0cfc6d04804f74dca2999ae98f0c8db93721b522f76f1"},
]
wcwidth = [
{file = "wcwidth-0.2.5-py2.py3-none-any.whl", hash = "sha256:beb4802a9cebb9144e99086eff703a642a13d6a0052920003a230f3294bbe784"},
@@ -4177,6 +4202,10 @@ werkzeug = [
{file = "Werkzeug-2.1.2-py3-none-any.whl", hash = "sha256:72a4b735692dd3135217911cbeaa1be5fa3f62bffb8745c5215420a03dc55255"},
{file = "Werkzeug-2.1.2.tar.gz", hash = "sha256:1ce08e8093ed67d638d63879fd1ba3735817f7a80de3674d293f5984f25fb6e6"},
]
+wheel = [
+ {file = "wheel-0.37.1-py2.py3-none-any.whl", hash = "sha256:4bdcd7d840138086126cd09254dc6195fb4fc6f01c050a1d7236f2630db1d22a"},
+ {file = "wheel-0.37.1.tar.gz", hash = "sha256:e9a504e793efbca1b8e0e9cb979a249cf4a0a7b5b8c9e8b65a5e39d49529c1c4"},
+]
widgetsnbextension = [
{file = "widgetsnbextension-3.6.0-py2.py3-none-any.whl", hash = "sha256:4fd321cad39fdcf8a8e248a657202d42917ada8e8ed5dd3f60f073e0d54ceabd"},
{file = "widgetsnbextension-3.6.0.tar.gz", hash = "sha256:e84a7a9fcb9baf3d57106e184a7389a8f8eb935bf741a5eb9d60aa18cc029a80"},
diff --git a/pyproject.toml b/pyproject.toml
index 2e7cf28..2f08ace 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,10 +16,10 @@ omegaconf = "^2.1.0"
einops = "^0.3.0"
pytorch-lightning = "^1.6.3"
hydra-core = "^1.1.1"
-attr = "^0.3.1"
smart-open = "^5.2.1"
torch = "^1.11.0"
torchvision = "^0.12.0"
+attrs = "^21.4.0"
[tool.poetry.dev-dependencies]
pytest = "^5.4.2"
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index 15c286a..a0c8416 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -2,7 +2,7 @@
from pathlib import Path
from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
-import attr
+from attrs import define, field
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
@@ -20,29 +20,29 @@ def load_and_print_info(data_module_class: type) -> None:
print(dataset)
-@attr.s(repr=False)
+@define(repr=False)
class BaseDataModule(LightningDataModule):
"""Base PyTorch Lightning DataModule."""
- def __attrs_pre_init__(self) -> None:
+ def __attrs_post_init__(self) -> None:
"""Pre init constructor."""
super().__init__()
- mapping: Type[AbstractMapping] = attr.ib()
- transform: Optional[Callable] = attr.ib(default=None)
- test_transform: Optional[Callable] = attr.ib(default=None)
- target_transform: Optional[Callable] = attr.ib(default=None)
- train_fraction: float = attr.ib(default=0.8)
- batch_size: int = attr.ib(default=16)
- num_workers: int = attr.ib(default=0)
- pin_memory: bool = attr.ib(default=True)
+ mapping: Type[AbstractMapping] = field()
+ transform: Optional[Callable] = field(default=None)
+ test_transform: Optional[Callable] = field(default=None)
+ target_transform: Optional[Callable] = field(default=None)
+ train_fraction: float = field(default=0.8)
+ batch_size: int = field(default=16)
+ num_workers: int = field(default=0)
+ pin_memory: bool = field(default=True)
# Placeholders
- data_train: BaseDataset = attr.ib(init=False, default=None)
- data_val: BaseDataset = attr.ib(init=False, default=None)
- data_test: BaseDataset = attr.ib(init=False, default=None)
- dims: Tuple[int, ...] = attr.ib(init=False, default=None)
- output_dims: Tuple[int, ...] = attr.ib(init=False, default=None)
+ data_train: BaseDataset = field(init=False, default=None)
+ data_val: BaseDataset = field(init=False, default=None)
+ data_test: BaseDataset = field(init=False, default=None)
+ dims: Tuple[int, ...] = field(init=False, default=None)
+ output_dims: Tuple[int, ...] = field(init=False, default=None)
@classmethod
def data_dirname(cls: T) -> Path:
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py
index b9567c7..c57cbcc 100644
--- a/text_recognizer/data/base_dataset.py
+++ b/text_recognizer/data/base_dataset.py
@@ -1,7 +1,7 @@
"""Base PyTorch Dataset class."""
from typing import Callable, Dict, Optional, Sequence, Tuple, Union
-import attr
+from attrs import define, field
import torch
from torch import Tensor
from torch.utils.data import Dataset
@@ -9,7 +9,7 @@ from torch.utils.data import Dataset
from text_recognizer.data.transforms.load_transform import load_transform_from_file
-@attr.s
+@define
class BaseDataset(Dataset):
r"""Base Dataset class that processes data and targets through optional transfroms.
@@ -21,10 +21,10 @@ class BaseDataset(Dataset):
target transforms.
"""
- data: Union[Sequence, Tensor] = attr.ib()
- targets: Union[Sequence, Tensor] = attr.ib()
- transform: Union[Optional[Callable], str] = attr.ib(default=None)
- target_transform: Union[Optional[Callable], str] = attr.ib(default=None)
+ data: Union[Sequence, Tensor] = field()
+ targets: Union[Sequence, Tensor] = field()
+ transform: Union[Optional[Callable], str] = field(default=None)
+ target_transform: Union[Optional[Callable], str] = field(default=None)
def __attrs_pre_init__(self) -> None:
"""Pre init constructor."""
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index dc8d31a..94882bf 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -6,7 +6,7 @@ import shutil
from typing import Dict, List, Optional, Sequence, Set, Tuple
import zipfile
-import attr
+from attrs import define
import h5py
from loguru import logger as log
import numpy as np
@@ -35,7 +35,7 @@ ESSENTIALS_FILENAME = (
)
-@attr.s(auto_attribs=True)
+@define(auto_attribs=True)
class EMNIST(BaseDataModule):
"""Lightning DataModule class for loading EMNIST dataset.
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index c267286..43d55b9 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -3,7 +3,7 @@ from collections import defaultdict
from pathlib import Path
from typing import DefaultDict, List, Tuple
-import attr
+from attrs import define, field
import h5py
from loguru import logger as log
import numpy as np
@@ -33,17 +33,17 @@ IMAGE_X_PADDING = 28
MAX_OUTPUT_LENGTH = 89 # Same as IAMLines
-@attr.s(auto_attribs=True, repr=False)
+@define(auto_attribs=True, repr=False)
class EMNISTLines(BaseDataModule):
"""EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST."""
- max_length: int = attr.ib(default=128)
- min_overlap: float = attr.ib(default=0.0)
- max_overlap: float = attr.ib(default=0.33)
- num_train: int = attr.ib(default=10_000)
- num_val: int = attr.ib(default=2_000)
- num_test: int = attr.ib(default=2_000)
- emnist: EMNIST = attr.ib(init=False, default=None)
+ max_length: int = field(default=128)
+ min_overlap: float = field(default=0.0)
+ max_overlap: float = field(default=0.33)
+ num_train: int = field(default=10_000)
+ num_val: int = field(default=2_000)
+ num_test: int = field(default=2_000)
+ emnist: EMNIST = field(init=False, default=None)
def __attrs_post_init__(self) -> None:
"""Post init constructor."""
diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py
index 766f3e0..8166863 100644
--- a/text_recognizer/data/iam.py
+++ b/text_recognizer/data/iam.py
@@ -9,7 +9,7 @@ from typing import Any, Dict, List
import xml.etree.ElementTree as ElementTree
import zipfile
-import attr
+from attrs import define, field
from boltons.cacheutils import cachedproperty
from loguru import logger as log
import toml
@@ -27,7 +27,7 @@ DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be.
LINE_REGION_PADDING = 16 # Add this many pixels around the exact coordinates.
-@attr.s(auto_attribs=True)
+@define(auto_attribs=True)
class IAM(BaseDataModule):
r"""The IAM Lines dataset.
@@ -44,7 +44,7 @@ class IAM(BaseDataModule):
contributed to one set only.
"""
- metadata: Dict = attr.ib(init=False, default=toml.load(METADATA_FILENAME))
+ metadata: Dict = field(init=False, default=toml.load(METADATA_FILENAME))
def prepare_data(self) -> None:
"""Prepares the IAM dataset."""
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 22d00f1..52c10c3 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -1,5 +1,5 @@
"""IAM original and sythetic dataset class."""
-import attr
+from attrs import define, field
from torch.utils.data import ConcatDataset
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
@@ -8,7 +8,7 @@ from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs
from text_recognizer.data.transforms.load_transform import load_transform_from_file
-@attr.s(auto_attribs=True, repr=False)
+@define(auto_attribs=True, repr=False)
class IAMExtendedParagraphs(BaseDataModule):
"""A dataset with synthetic and real handwritten paragraph."""
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index a79c202..34cf605 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -7,7 +7,7 @@ import json
from pathlib import Path
from typing import List, Sequence, Tuple
-import attr
+from attrs import define, field
from loguru import logger as log
import numpy as np
from PIL import Image, ImageFile, ImageOps
@@ -35,14 +35,14 @@ MAX_LABEL_LENGTH = 89
MAX_WORD_PIECE_LENGTH = 72
-@attr.s(auto_attribs=True, repr=False)
+@define(auto_attribs=True, repr=False)
class IAMLines(BaseDataModule):
"""IAM handwritten lines dataset."""
- dims: Tuple[int, int, int] = attr.ib(
+ dims: Tuple[int, int, int] = field(
init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH)
)
- output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1))
+ output_dims: Tuple[int, int] = field(init=False, default=(MAX_LABEL_LENGTH, 1))
def prepare_data(self) -> None:
"""Creates the IAM lines dataset if not existing."""
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 033b93e..b605bbc 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -3,7 +3,7 @@ import json
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple
-import attr
+from attrs import define, field
from loguru import logger as log
import numpy as np
from PIL import Image, ImageOps
@@ -33,15 +33,15 @@ MAX_LABEL_LENGTH = 682
MAX_WORD_PIECE_LENGTH = 451
-@attr.s(auto_attribs=True, repr=False)
+@define(auto_attribs=True, repr=False)
class IAMParagraphs(BaseDataModule):
"""IAM handwriting database paragraphs."""
# Placeholders
- dims: Tuple[int, int, int] = attr.ib(
+ dims: Tuple[int, int, int] = field(
init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH)
)
- output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1))
+ output_dims: Tuple[int, int] = field(init=False, default=(MAX_LABEL_LENGTH, 1))
def prepare_data(self) -> None:
"""Create data for training/testing."""
@@ -86,7 +86,10 @@ class IAMParagraphs(BaseDataModule):
length=self.output_dims[0],
)
return BaseDataset(
- data, targets, transform=transform, target_transform=target_transform,
+ data,
+ targets,
+ transform=transform,
+ target_transform=target_transform,
)
log.info(f"Loading IAM paragraph regions and lines for {stage}...")
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index ea59098..7143951 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -2,7 +2,7 @@
import random
from typing import Any, List, Sequence, Tuple
-import attr
+from attrs import define
from loguru import logger as log
import numpy as np
from PIL import Image
@@ -34,7 +34,7 @@ PROCESSED_DATA_DIRNAME = (
)
-@attr.s(auto_attribs=True, repr=False)
+@define(auto_attribs=True, repr=False)
class IAMSyntheticParagraphs(IAMParagraphs):
"""IAM Handwriting database of synthetic paragraphs."""
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 821cb69..bf3bc08 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -1,7 +1,7 @@
"""Base PyTorch Lightning model."""
from typing import Any, Dict, List, Optional, Tuple, Type
-import attr
+from attrs import define, field
import hydra
from loguru import logger as log
from omegaconf import DictConfig
@@ -14,7 +14,7 @@ import torchmetrics
from text_recognizer.data.mappings.base import AbstractMapping
-@attr.s(eq=False)
+@define(eq=False)
class BaseLitModel(LightningModule):
"""Abstract PyTorch Lightning class."""
@@ -22,22 +22,18 @@ class BaseLitModel(LightningModule):
"""Pre init constructor."""
super().__init__()
- network: Type[nn.Module] = attr.ib()
- loss_fn: Type[nn.Module] = attr.ib()
- optimizer_configs: DictConfig = attr.ib()
- lr_scheduler_configs: Optional[DictConfig] = attr.ib()
- mapping: Type[AbstractMapping] = attr.ib()
+ network: Type[nn.Module] = field()
+ loss_fn: Type[nn.Module] = field()
+ optimizer_configs: DictConfig = field()
+ lr_scheduler_configs: Optional[DictConfig] = field()
+ mapping: Type[AbstractMapping] = field()
# Placeholders
- train_acc: torchmetrics.Accuracy = attr.ib(
- init=False, default=torchmetrics.Accuracy()
- )
- val_acc: torchmetrics.Accuracy = attr.ib(
- init=False, default=torchmetrics.Accuracy()
- )
- test_acc: torchmetrics.Accuracy = attr.ib(
+ train_acc: torchmetrics.Accuracy = field(
init=False, default=torchmetrics.Accuracy()
)
+ val_acc: torchmetrics.Accuracy = field(init=False, default=torchmetrics.Accuracy())
+ test_acc: torchmetrics.Accuracy = field(init=False, default=torchmetrics.Accuracy())
def optimizer_zero_grad(
self,
diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py
index f83c9e4..e59a830 100644
--- a/text_recognizer/models/metrics.py
+++ b/text_recognizer/models/metrics.py
@@ -1,20 +1,20 @@
"""Character Error Rate (CER)."""
from typing import Set
-import attr
+from attrs import define, field
import editdistance
import torch
from torch import Tensor
from torchmetrics import Metric
-@attr.s(eq=False)
+@define(eq=False)
class CharacterErrorRate(Metric):
"""Character error rate metric, computed using Levenshtein distance."""
- ignore_indices: Set[Tensor] = attr.ib(converter=set)
- error: Tensor = attr.ib(init=False)
- total: Tensor = attr.ib(init=False)
+ ignore_indices: Set[Tensor] = field(converter=set)
+ error: Tensor = field(init=False)
+ total: Tensor = field(init=False)
def __attrs_post_init__(self) -> None:
super().__init__()
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 7272f46..c5120fe 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,7 +1,7 @@
"""PyTorch Lightning model for base Transformers."""
from typing import Set, Tuple
-import attr
+from attrs import define, field
import torch
from torch import Tensor
@@ -9,22 +9,22 @@ from text_recognizer.models.base import BaseLitModel
from text_recognizer.models.metrics import CharacterErrorRate
-@attr.s(auto_attribs=True, eq=False)
+@define(auto_attribs=True, eq=False)
class TransformerLitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
- max_output_len: int = attr.ib(default=451)
- start_token: str = attr.ib(default="<s>")
- end_token: str = attr.ib(default="<e>")
- pad_token: str = attr.ib(default="<p>")
+ max_output_len: int = field(default=451)
+ start_token: str = field(default="<s>")
+ end_token: str = field(default="<e>")
+ pad_token: str = field(default="<p>")
- start_index: int = attr.ib(init=False)
- end_index: int = attr.ib(init=False)
- pad_index: int = attr.ib(init=False)
+ start_index: int = field(init=False)
+ end_index: int = field(init=False)
+ pad_index: int = field(init=False)
- ignore_indices: Set[Tensor] = attr.ib(init=False)
- val_cer: CharacterErrorRate = attr.ib(init=False)
- test_cer: CharacterErrorRate = attr.ib(init=False)
+ ignore_indices: Set[Tensor] = field(init=False)
+ val_cer: CharacterErrorRate = field(init=False)
+ test_cer: CharacterErrorRate = field(init=False)
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py
index 4c9ed75..cf64bcf 100644
--- a/text_recognizer/networks/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/efficientnet/efficientnet.py
@@ -1,7 +1,7 @@
"""Efficientnet backbone."""
from typing import Tuple
-import attr
+from attrs import define, field
from torch import nn, Tensor
from text_recognizer.networks.efficientnet.mbconv import MBConvBlock
@@ -12,7 +12,7 @@ from text_recognizer.networks.efficientnet.utils import (
)
-@attr.s(eq=False)
+@define(eq=False)
class EfficientNet(nn.Module):
"""Efficientnet without classification head."""
@@ -33,28 +33,28 @@ class EfficientNet(nn.Module):
"l2": (4.3, 5.3, 0.5),
}
- arch: str = attr.ib()
- params: Tuple[float, float, float] = attr.ib(default=None, init=False)
- stochastic_dropout_rate: float = attr.ib(default=0.2)
- bn_momentum: float = attr.ib(default=0.99)
- bn_eps: float = attr.ib(default=1.0e-3)
- depth: int = attr.ib(default=7)
- out_channels: int = attr.ib(default=None, init=False)
- _conv_stem: nn.Sequential = attr.ib(default=None, init=False)
- _blocks: nn.ModuleList = attr.ib(default=None, init=False)
- _conv_head: nn.Sequential = attr.ib(default=None, init=False)
+ arch: str = field()
+ params: Tuple[float, float, float] = field(default=None, init=False)
+ stochastic_dropout_rate: float = field(default=0.2)
+ bn_momentum: float = field(default=0.99)
+ bn_eps: float = field(default=1.0e-3)
+ depth: int = field(default=7)
+ out_channels: int = field(default=None, init=False)
+ _conv_stem: nn.Sequential = field(default=None, init=False)
+ _blocks: nn.ModuleList = field(default=None, init=False)
+ _conv_head: nn.Sequential = field(default=None, init=False)
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
self._build()
@depth.validator
- def _check_depth(self, attribute: attr._make.Attribute, value: str) -> None:
+ def _check_depth(self, attribute, value: str) -> None:
if not 5 <= value <= 7:
raise ValueError(f"Depth has to be between 5 and 7, was: {value}")
@arch.validator
- def _check_arch(self, attribute: attr._make.Attribute, value: str) -> None:
+ def _check_arch(self, attribute, value: str) -> None:
"""Validates the efficientnet architecure."""
if value not in self.archs:
raise ValueError(f"{value} not a valid architecure.")
@@ -88,7 +88,9 @@ class EfficientNet(nn.Module):
for _ in range(num_repeats):
self._blocks.append(
MBConvBlock(
- **args, bn_momentum=self.bn_momentum, bn_eps=self.bn_eps,
+ **args,
+ bn_momentum=self.bn_momentum,
+ bn_eps=self.bn_eps,
)
)
args.in_channels = args.out_channels
diff --git a/text_recognizer/networks/efficientnet/mbconv.py b/text_recognizer/networks/efficientnet/mbconv.py
index beb7d57..98e9353 100644
--- a/text_recognizer/networks/efficientnet/mbconv.py
+++ b/text_recognizer/networks/efficientnet/mbconv.py
@@ -1,7 +1,7 @@
"""Mobile inverted residual block."""
from typing import Optional, Tuple, Union
-import attr
+from attrs import define, field
import torch
from torch import nn, Tensor
import torch.nn.functional as F
@@ -14,13 +14,13 @@ def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]:
return (stride,) * 2 if isinstance(stride, int) else stride
-@attr.s(eq=False)
+@define(eq=False)
class BaseModule(nn.Module):
"""Base sub module class."""
- bn_momentum: float = attr.ib()
- bn_eps: float = attr.ib()
- block: nn.Sequential = attr.ib(init=False)
+ bn_momentum: float = field()
+ bn_eps: float = field()
+ block: nn.Sequential = field(init=False)
def __attrs_pre_init__(self) -> None:
super().__init__()
@@ -36,12 +36,12 @@ class BaseModule(nn.Module):
return self.block(x)
-@attr.s(auto_attribs=True, eq=False)
+@define(auto_attribs=True, eq=False)
class InvertedBottleneck(BaseModule):
"""Inverted bottleneck module."""
- in_channels: int = attr.ib()
- out_channels: int = attr.ib()
+ in_channels: int = field()
+ out_channels: int = field()
def _build(self) -> None:
self.block = nn.Sequential(
@@ -60,13 +60,13 @@ class InvertedBottleneck(BaseModule):
)
-@attr.s(auto_attribs=True, eq=False)
+@define(auto_attribs=True, eq=False)
class Depthwise(BaseModule):
"""Depthwise convolution module."""
- channels: int = attr.ib()
- kernel_size: int = attr.ib()
- stride: int = attr.ib()
+ channels: int = field()
+ kernel_size: int = field()
+ stride: int = field()
def _build(self) -> None:
self.block = nn.Sequential(
@@ -85,13 +85,13 @@ class Depthwise(BaseModule):
)
-@attr.s(auto_attribs=True, eq=False)
+@define(auto_attribs=True, eq=False)
class SqueezeAndExcite(BaseModule):
"""Sequeeze and excite module."""
- in_channels: int = attr.ib()
- channels: int = attr.ib()
- se_ratio: float = attr.ib()
+ in_channels: int = field()
+ channels: int = field()
+ se_ratio: float = field()
def _build(self) -> None:
num_squeezed_channels = max(1, int(self.in_channels * self.se_ratio))
@@ -110,12 +110,12 @@ class SqueezeAndExcite(BaseModule):
)
-@attr.s(auto_attribs=True, eq=False)
+@define(auto_attribs=True, eq=False)
class Pointwise(BaseModule):
"""Pointwise module."""
- in_channels: int = attr.ib()
- out_channels: int = attr.ib()
+ in_channels: int = field()
+ out_channels: int = field()
def _build(self) -> None:
self.block = nn.Sequential(
@@ -133,32 +133,35 @@ class Pointwise(BaseModule):
)
-@attr.s(eq=False)
+@define(eq=False)
class MBConvBlock(nn.Module):
"""Mobile Inverted Residual Bottleneck block."""
def __attrs_pre_init__(self) -> None:
super().__init__()
- in_channels: int = attr.ib()
- out_channels: int = attr.ib()
- kernel_size: Tuple[int, int] = attr.ib()
- stride: Tuple[int, int] = attr.ib(converter=_convert_stride)
- bn_momentum: float = attr.ib()
- bn_eps: float = attr.ib()
- se_ratio: float = attr.ib()
- expand_ratio: int = attr.ib()
- pad: Tuple[int, int, int, int] = attr.ib(init=False)
- _inverted_bottleneck: Optional[InvertedBottleneck] = attr.ib(init=False)
- _depthwise: nn.Sequential = attr.ib(init=False)
- _squeeze_excite: nn.Sequential = attr.ib(init=False)
- _pointwise: nn.Sequential = attr.ib(init=False)
+ in_channels: int = field()
+ out_channels: int = field()
+ kernel_size: Tuple[int, int] = field()
+ stride: Tuple[int, int] = field(converter=_convert_stride)
+ bn_momentum: float = field()
+ bn_eps: float = field()
+ se_ratio: float = field()
+ expand_ratio: int = field()
+ pad: Tuple[int, int, int, int] = field(init=False)
+ _inverted_bottleneck: Optional[InvertedBottleneck] = field(init=False)
+ _depthwise: nn.Sequential = field(init=False)
+ _squeeze_excite: nn.Sequential = field(init=False)
+ _pointwise: nn.Sequential = field(init=False)
@pad.default
def _configure_padding(self) -> Tuple[int, int, int, int]:
"""Set padding for convolutional layers."""
if self.stride == (2, 2):
- return ((self.kernel_size - 1) // 2 - 1, (self.kernel_size - 1) // 2,) * 2
+ return (
+ (self.kernel_size - 1) // 2 - 1,
+ (self.kernel_size - 1) // 2,
+ ) * 2
return ((self.kernel_size - 1) // 2,) * 4
def __attrs_post_init__(self) -> None:
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 87792a9..aa15b88 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -1,7 +1,7 @@
"""Implementes the attention module for the transformer."""
from typing import Optional, Tuple
-import attr
+from attrs import define, field
from einops import rearrange
import torch
from torch import einsum
@@ -15,22 +15,22 @@ from text_recognizer.networks.transformer.embeddings.rotary import (
)
-@attr.s(eq=False)
+@define(eq=False)
class Attention(nn.Module):
"""Standard attention."""
def __attrs_pre_init__(self) -> None:
super().__init__()
- dim: int = attr.ib()
- num_heads: int = attr.ib()
- causal: bool = attr.ib(default=False)
- dim_head: int = attr.ib(default=64)
- dropout_rate: float = attr.ib(default=0.0)
- rotary_embedding: Optional[RotaryEmbedding] = attr.ib(default=None)
- scale: float = attr.ib(init=False)
- dropout: nn.Dropout = attr.ib(init=False)
- fc: nn.Linear = attr.ib(init=False)
+ dim: int = field()
+ num_heads: int = field()
+ causal: bool = field(default=False)
+ dim_head: int = field(default=64)
+ dropout_rate: float = field(default=0.0)
+ rotary_embedding: Optional[RotaryEmbedding] = field(default=None)
+ scale: float = field(init=False)
+ dropout: nn.Dropout = field(init=False)
+ fc: nn.Linear = field(init=False)
def __attrs_post_init__(self) -> None:
self.scale = self.dim ** -0.5
@@ -120,7 +120,6 @@ def apply_input_mask(
input_mask = q_mask * k_mask
energy = energy.masked_fill_(~input_mask, mask_value)
- del input_mask
return energy
@@ -133,5 +132,4 @@ def apply_causal_mask(
mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j")
mask = F.pad(mask, (j - i, 0), value=False)
energy.masked_fill_(mask, mask_value)
- del mask
return energy