summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2021-01-24 22:14:17 +0100
committeraktersnurra <gustaf.rydholm@gmail.com>2021-01-24 22:14:17 +0100
commit4a54d7e690897dd6e6c719fb908fd371a44c2952 (patch)
tree04722ac94b9c3960baa5db7939d7ef01dbf535a6
parentd691b548cd0b6fc4ea184d64261f633789fee021 (diff)
Many updates, cool stuff on the way.
-rw-r--r--.flake84
-rw-r--r--README.md68
-rw-r--r--poetry.lock495
-rw-r--r--pyproject.toml14
-rw-r--r--src/notebooks/00-testing-stuff-out.ipynb312
-rw-r--r--src/notebooks/02c-image-patches.ipynb65
-rw-r--r--src/notebooks/04a-look-at-iam-lines.ipynb295
-rw-r--r--src/notebooks/06-try-transformer-model-predictions.ipynb (renamed from src/notebooks/Untitled.ipynb)0
-rw-r--r--src/notebooks/07-look-at-lexicon.ipynb1119
-rw-r--r--src/notebooks/07-try-gtn.ipynb155
-rw-r--r--src/notebooks/g1.pngbin0 -> 8590 bytes
-rw-r--r--src/notebooks/g2.pngbin0 -> 5247 bytes
-rw-r--r--src/notebooks/intersect.pngbin0 -> 7953 bytes
-rw-r--r--src/tasks/build_transitions.py263
-rw-r--r--src/tasks/make_wordpieces.py114
-rw-r--r--src/text_recognizer/datasets/__init__.py3
-rw-r--r--src/text_recognizer/datasets/iam_preprocessor.py196
-rw-r--r--src/text_recognizer/datasets/transforms.py45
-rw-r--r--src/text_recognizer/models/__init__.py2
-rw-r--r--src/text_recognizer/models/base.py2
-rw-r--r--src/text_recognizer/models/transformer_model.py12
-rw-r--r--src/text_recognizer/models/vqvae_model.py80
-rw-r--r--src/text_recognizer/networks/__init__.py8
-rw-r--r--src/text_recognizer/networks/cnn.py101
-rw-r--r--src/text_recognizer/networks/cnn_transformer.py15
-rw-r--r--src/text_recognizer/networks/metrics.py33
-rw-r--r--src/text_recognizer/networks/transducer/__init__.py2
-rw-r--r--src/text_recognizer/networks/transducer/tds_conv.py205
-rw-r--r--src/text_recognizer/networks/util.py9
-rw-r--r--src/text_recognizer/networks/vq_transformer.py150
-rw-r--r--src/text_recognizer/networks/vqvae/__init__.py4
-rw-r--r--src/text_recognizer/networks/vqvae/decoder.py133
-rw-r--r--src/text_recognizer/networks/vqvae/encoder.py125
-rw-r--r--src/text_recognizer/networks/vqvae/vector_quantizer.py2
-rw-r--r--src/text_recognizer/networks/vqvae/vqvae.py74
-rw-r--r--src/text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.ptbin0 -> 21687018 bytes
-rw-r--r--src/training/run_experiment.py7
-rw-r--r--src/training/trainer/callbacks/__init__.py8
-rw-r--r--src/training/trainer/callbacks/wandb_callbacks.py58
-rw-r--r--src/training/trainer/train.py94
40 files changed, 3541 insertions, 731 deletions
diff --git a/.flake8 b/.flake8
index 8e1c0eb..eff48a6 100644
--- a/.flake8
+++ b/.flake8
@@ -1,9 +1,9 @@
[flake8]
select = ANN,B,B9,BLK,C,D,DAR,E,F,I,S,W
-ignore = E203,E501,W503,ANN101,ANN002,ANN003,F401,D202,S404,D107,S607,S603,S310,S106
+ignore = E203,E501,W503,ANN101,ANN002,ANN003,F401,D202,S404,D107,S607,S603,S310,S106,S311
max-line-length = 120
max-complexity = 10
application-import-names = text_recognizer,tests
import-order-style = google
docstring-convention = google
-per-file-ignores = tests/*:S101,tests/*:S106,src/text_recognizer/datasets/*:S110,src/training/callbacks/*:B006
+per-file-ignores = tests/*:S101,tests/*:S106,src/text_recognizer/datasets/*:S110,src/training/callbacks/*:B006,src/tasks/build_transitions.py:C901
diff --git a/README.md b/README.md
index 0330372..a589c92 100644
--- a/README.md
+++ b/README.md
@@ -1,64 +1,28 @@
# Text Recognizer
-Implementing the text recognizer project from the course ["Full Stack Deep Learning Course"](https://fullstackdeeplearning.com/march2019) in PyTorch in order to learn best practices when building a deep learning project. I have expanded on this project by adding additional feature and ideas given by Claudio Jolowicz in ["Hypermodern Python"](https://cjolowicz.github.io/posts/hypermodern-python-01-setup/).
+Implementing the text recognizer project from the course ["Full Stack Deep Learning Course"](https://fullstackdeeplearning.com/march2019) (FSDL) in PyTorch in order to learn best practices when building a deep learning project. I have expanded on this project by adding additional feature and ideas given by Claudio Jolowicz in ["Hypermodern Python"](https://cjolowicz.github.io/posts/hypermodern-python-01-setup/).
## Setup
TBC
-## Todo
-- [x] subsampling
-- [x] Be able to run experiments
-- [x] Train models
-- [x] Fix input size in base model
-- [x] Fix s.t. the best weights are saved
-- [x] Implement total training time
-- [x] Fix tqdm and logging output
-- [x] Fix basic test to load model
-- [x] Fix loading previous experiments
-- [x] Able to set verbosity level on the logger to terminal output
-- [x] Implement Callbacks for training
- - [x] Implement early stopping
- - [x] Implement wandb
- - [x] Implement lr scheduler as a callback
- - [x] Implement save checkpoint callback
- - [x] Implement TQDM progress bar (Low priority)
-- [ ] Check that dataset exists, otherwise download it form the web. Do this in run_experiment.py.
-- [x] Create repr func for data loaders
-- [ ] Be able to restart with lr scheduler (May skip this)
-- [ ] Implement population based training
-- [x] Implement Bayesian hyperparameter search (with W&B maybe)
-- [x] Try to fix shell cmd security issues S404, S602
-- [x] Change prepare_experiment.py to print statements st it can be run with tasks/prepare_sample_experiments.sh | parallel -j1
-- [x] Fix caption in WandbImageLogger
-- [x] Rename val_accuracy in metric
-- [x] Start implementing callback list stuff in train.py
-- [x] Fix s.t. callbacks can be loaded in run_experiment.py
-- [x] Lift out Emnist dataset out of Emnist dataloaders
-- [x] Finish Emnist line dataset
-- [x] SentenceGenerator
-- [x] Write a Emnist line data loader
-- [x] Implement ctc line model
- - [x] Implement CNN encoder (ResNet style)
- - [x] Implement the RNN + output layer
- - [x] Construct/implement the CTC loss
-- [x] Sweep base config yaml file
-- [x] sweep.py
-- [x] sweep.yaml
-- [x] Fix dataset splits.
-- [x] Implement predict on image
-- [x] CTC decoder
-- [x] IAM dataset
-- [x] IAM Lines dataset
-- [x] IAM paragraphs dataset
-- [ ] CNN + Transformer (!!)
-- [ ] CNN + GPT
-- [ ] fix nosec problem
-- [x] common Dataset class
-- [x] Fix CTC blank stuff and varying length
-- [x] Metric Learning for backbone training
+
+
+## Todo
+- [ ] create wordpieces
+ - [x] make_wordpieces.py
+ - [x] build_transitions.py
+ - [ ] transform that encodes iam targets to wordpieces
+ - [ ] transducer loss function
+- [ ] Predictive coding
+ - https://arxiv.org/pdf/1807.03748.pdf
+ - https://arxiv.org/pdf/1904.05862.pdf
+ - https://arxiv.org/pdf/1910.05453.pdf
+ - https://blog.evjang.com/2016/11/tutorial-categorical-variational.html
+ - [ ]
+
## Run Sweeps
Run the following commands to execute hyperparameter search with W&B:
diff --git a/poetry.lock b/poetry.lock
index c0c061c..7f715d8 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -15,7 +15,7 @@ python-versions = "*"
version = "1.4.4"
[[package]]
-category = "dev"
+category = "main"
description = "Disable App Nap on OS X 10.9"
marker = "sys_platform == \"darwin\" or platform_system == \"Darwin\" or python_version >= \"3.3\" and sys_platform == \"darwin\""
name = "appnope"
@@ -24,7 +24,7 @@ python-versions = "*"
version = "0.1.0"
[[package]]
-category = "dev"
+category = "main"
description = "The secure Argon2 password hashing algorithm."
name = "argon2-cffi"
optional = false
@@ -41,7 +41,7 @@ docs = ["sphinx"]
tests = ["coverage (>=5.0.2)", "hypothesis", "pytest"]
[[package]]
-category = "dev"
+category = "main"
description = "Async generators and context managers for Python 3.5+"
name = "async-generator"
optional = false
@@ -83,7 +83,7 @@ version = "2.9.0"
pytz = ">=2015.7"
[[package]]
-category = "dev"
+category = "main"
description = "Specifications for callback functions passed in to an API"
name = "backcall"
optional = false
@@ -126,7 +126,7 @@ typed-ast = ">=1.4.0"
d = ["aiohttp (>=3.3.2)", "aiohttp-cors"]
[[package]]
-category = "dev"
+category = "main"
description = "An easy safelist-based HTML-sanitizing tool."
name = "bleach"
optional = false
@@ -166,7 +166,7 @@ python-versions = "*"
version = "2020.11.8"
[[package]]
-category = "dev"
+category = "main"
description = "Foreign Function Interface for Python calling C code."
name = "cffi"
optional = false
@@ -245,8 +245,8 @@ category = "main"
description = "A utility for ensuring Google-style docstrings stay up to date with the source code."
name = "darglint"
optional = false
-python-versions = ">=3.5,<4.0"
-version = "1.5.5"
+python-versions = ">=3.6,<4.0"
+version = "1.5.6"
[[package]]
category = "main"
@@ -257,7 +257,7 @@ python-versions = "*"
version = "0.6"
[[package]]
-category = "dev"
+category = "main"
description = "Decorators for Humans"
name = "decorator"
optional = false
@@ -277,17 +277,17 @@ category = "main"
description = "Deserialize to objects while staying DRY"
name = "desert"
optional = false
-python-versions = "*"
-version = "2020.1.6"
+python-versions = ">=3.6"
+version = "2020.11.18"
[package.dependencies]
attrs = "*"
-dataclasses = "*"
marshmallow = ">=3.0"
typing-inspect = "*"
[package.extras]
-dev = ["coverage", "cuvner", "pytest", "tox", "versioneer", "black", "pylint", "pex", "bump2version", "docutils", "check-manifest", "readme-renderer", "pygments", "isort", "mypy", "pytest-sphinx", "towncrier", "marshmallow-union", "marshmallow-enum", "twine", "wheel"]
+dev = ["coverage", "cuvner", "marshmallow-enum", "marshmallow-union", "pytest", "pytest-cov", "pytest-sphinx", "pytest-travis-fold", "tox", "importlib-metadata", "versioneer", "black", "pylint", "pex", "bump2version", "docutils", "check-manifest", "readme-renderer", "pygments", "isort", "mypy", "towncrier", "twine", "wheel"]
+test = ["coverage", "cuvner", "marshmallow-enum", "marshmallow-union", "pytest", "pytest-cov", "pytest-sphinx", "pytest-travis-fold", "tox", "importlib-metadata"]
[[package]]
category = "main"
@@ -330,10 +330,10 @@ description = "A new flavour of deep learning operations"
name = "einops"
optional = false
python-versions = "*"
-version = "0.2.0"
+version = "0.3.0"
[[package]]
-category = "dev"
+category = "main"
description = "Discover and load entry points from installed packages."
name = "entrypoints"
optional = false
@@ -353,10 +353,6 @@ mccabe = ">=0.6.0,<0.7.0"
pycodestyle = ">=2.6.0a1,<2.7.0"
pyflakes = ">=2.2.0,<2.3.0"
-[package.dependencies.importlib-metadata]
-python = "<3.8"
-version = "*"
-
[[package]]
category = "main"
description = "Flake8 Type Annotation Checks"
@@ -368,10 +364,6 @@ version = "2.4.1"
[package.dependencies]
flake8 = ">=3.7,<3.9"
-[package.dependencies.typed-ast]
-python = "<3.8"
-version = ">=1.4,<2.0"
-
[[package]]
category = "dev"
description = "Automated security testing with bandit and flake8."
@@ -493,34 +485,25 @@ six = ">=1.7"
test = ["mock (>=2.0.0)", "pytest (<5.0)"]
[[package]]
-category = "main"
-description = "GraphQL client for Python"
-name = "gql"
+category = "dev"
+description = "Simple Python interface for Graphviz"
+name = "graphviz"
optional = false
-python-versions = "*"
-version = "0.2.0"
+python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*"
+version = "0.16"
-[package.dependencies]
-graphql-core = ">=0.5.0,<2"
-promise = ">=2.0,<3"
-requests = ">=2.12,<3"
-six = ">=1.10.0"
+[package.extras]
+dev = ["tox (>=3)", "flake8", "pep8-naming", "wheel", "twine"]
+docs = ["sphinx (>=1.8)", "sphinx-rtd-theme"]
+test = ["mock (>=3)", "pytest (>=4)", "pytest-mock (>=2)", "pytest-cov"]
[[package]]
category = "main"
-description = "GraphQL implementation for Python"
-name = "graphql-core"
+description = "Automatic differentiation with WFSTs"
+name = "gtn"
optional = false
-python-versions = "*"
-version = "1.1"
-
-[package.dependencies]
-promise = ">=2.0"
-six = ">=1.10.0"
-
-[package.extras]
-gevent = ["gevent (1.1rc1)"]
-test = ["pytest (3.0.2)", "pytest-django (2.9.1)", "pytest-cov (2.3.1)", "coveralls", "gevent (1.1rc1)", "six (>=1.10.0)", "pytest-benchmark (3.0.0)", "pytest-mock (1.2)"]
+python-versions = ">=3.5"
+version = "0.0.0"
[[package]]
category = "main"
@@ -551,36 +534,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
version = "1.2.0"
[[package]]
-category = "dev"
-description = "A library to calculate python dependency graphs."
-marker = "python_version == \"3.7\""
-name = "importlab"
-optional = false
-python-versions = ">=2.7.0"
-version = "0.5.1"
-
-[package.dependencies]
-networkx = "*"
-six = "*"
-
-[[package]]
category = "main"
-description = "Read metadata from Python packages"
-marker = "python_version < \"3.8\""
-name = "importlib-metadata"
-optional = false
-python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7"
-version = "2.0.0"
-
-[package.dependencies]
-zipp = ">=0.5"
-
-[package.extras]
-docs = ["sphinx", "rst.linker"]
-testing = ["packaging", "pep517", "importlib-resources (>=1.3)"]
-
-[[package]]
-category = "dev"
description = "IPython Kernel for Jupyter"
name = "ipykernel"
optional = false
@@ -598,7 +552,7 @@ traitlets = ">=4.1.0"
test = ["pytest (!=5.3.4)", "pytest-cov", "flaky", "nose"]
[[package]]
-category = "dev"
+category = "main"
description = "IPython: Productive Interactive Computing"
name = "ipython"
optional = false
@@ -630,7 +584,7 @@ qtconsole = ["qtconsole"]
test = ["nose (>=0.10.1)", "requests", "testpath", "pygments", "nbformat", "ipykernel", "numpy (>=1.14)"]
[[package]]
-category = "dev"
+category = "main"
description = "Vestigial utilities from IPython"
name = "ipython-genutils"
optional = false
@@ -659,7 +613,7 @@ version = ">=4.0.0"
test = ["pytest (>=3.6.0)", "pytest-cov", "mock"]
[[package]]
-category = "dev"
+category = "main"
description = "An autocompletion tool for Python that can be used for text editors."
name = "jedi"
optional = false
@@ -696,7 +650,7 @@ python-versions = ">=3.6"
version = "0.17.0"
[[package]]
-category = "dev"
+category = "main"
description = "An implementation of JSON Schema validation for Python"
name = "jsonschema"
optional = false
@@ -709,10 +663,6 @@ pyrsistent = ">=0.14.0"
setuptools = "*"
six = ">=1.11.0"
-[package.dependencies.importlib-metadata]
-python = "<3.8"
-version = "*"
-
[package.extras]
format = ["idna", "jsonpointer (>1.13)", "rfc3987", "strict-rfc3339", "webcolors"]
format_nongpl = ["idna", "jsonpointer (>1.13)", "webcolors", "rfc3986-validator (>0.1.0)", "rfc3339-validator"]
@@ -734,7 +684,7 @@ notebook = "*"
qtconsole = "*"
[[package]]
-category = "dev"
+category = "main"
description = "Jupyter protocol implementation and client libraries"
name = "jupyter-client"
optional = false
@@ -770,19 +720,19 @@ pygments = "*"
test = ["pexpect"]
[[package]]
-category = "dev"
+category = "main"
description = "Jupyter core package. A base package on which Jupyter projects rely."
name = "jupyter-core"
optional = false
-python-versions = "!=3.0,!=3.1,!=3.2,!=3.3,!=3.4,>=2.7"
-version = "4.6.3"
+python-versions = ">=3.6"
+version = "4.7.0"
[package.dependencies]
pywin32 = ">=1.0"
traitlets = "*"
[[package]]
-category = "dev"
+category = "main"
description = "Pygments theme using JupyterLab CSS variables"
name = "jupyterlab-pygments"
optional = false
@@ -794,6 +744,21 @@ pygments = ">=2.4.1,<3"
[[package]]
category = "main"
+description = "Select and install a Jupyter notebook theme"
+name = "jupyterthemes"
+optional = false
+python-versions = "*"
+version = "0.20.0"
+
+[package.dependencies]
+ipython = ">=5.4.1"
+jupyter-core = "*"
+lesscpy = ">=0.11.2"
+matplotlib = ">=1.4.3"
+notebook = ">=5.6.0"
+
+[[package]]
+category = "main"
description = "A fast implementation of the Cassowary constraint solver"
name = "kiwisolver"
optional = false
@@ -802,6 +767,18 @@ version = "1.3.1"
[[package]]
category = "main"
+description = "Python LESS compiler"
+name = "lesscpy"
+optional = false
+python-versions = "*"
+version = "0.14.0"
+
+[package.dependencies]
+ply = "*"
+six = "*"
+
+[[package]]
+category = "main"
description = "Python logging made (stupidly) simple"
name = "loguru"
optional = false
@@ -862,7 +839,7 @@ python-versions = "*"
version = "0.6.1"
[[package]]
-category = "dev"
+category = "main"
description = "The fastest markdown parser in pure Python"
name = "mistune"
optional = false
@@ -902,7 +879,7 @@ python-versions = "*"
version = "0.4.3"
[[package]]
-category = "dev"
+category = "main"
description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor."
name = "nbclient"
optional = false
@@ -922,7 +899,7 @@ sphinx = ["Sphinx (>=1.7)", "sphinx-book-theme", "mock", "moto", "myst-parser"]
test = ["codecov", "coverage", "ipython", "ipykernel", "ipywidgets", "pytest (>=4.1)", "pytest-cov (>=2.6.1)", "check-manifest", "flake8", "mypy", "tox", "bumpversion", "xmltodict", "pip (>=18.1)", "wheel (>=0.31.0)", "setuptools (>=38.6.0)", "twine (>=1.11.0)", "black"]
[[package]]
-category = "dev"
+category = "main"
description = "Converting Jupyter Notebooks"
name = "nbconvert"
optional = false
@@ -952,7 +929,7 @@ test = ["pytest", "pytest-cov", "pytest-dependency", "ipykernel", "ipywidgets (>
webpdf = ["pyppeteer (0.2.2)"]
[[package]]
-category = "dev"
+category = "main"
description = "The Jupyter Notebook format"
name = "nbformat"
optional = false
@@ -970,7 +947,7 @@ fast = ["fastjsonschema"]
test = ["fastjsonschema", "testpath", "pytest", "pytest-cov"]
[[package]]
-category = "dev"
+category = "main"
description = "Patch asyncio to allow nested event loops"
name = "nest-asyncio"
optional = false
@@ -978,40 +955,6 @@ python-versions = ">=3.5"
version = "1.4.3"
[[package]]
-category = "dev"
-description = "Python package for creating and manipulating graphs and networks"
-marker = "python_version == \"3.7\""
-name = "networkx"
-optional = false
-python-versions = ">=3.6"
-version = "2.5"
-
-[package.dependencies]
-decorator = ">=4.3.0"
-
-[package.extras]
-all = ["numpy", "scipy", "pandas", "matplotlib", "pygraphviz", "pydot", "pyyaml", "lxml", "pytest"]
-gdal = ["gdal"]
-lxml = ["lxml"]
-matplotlib = ["matplotlib"]
-numpy = ["numpy"]
-pandas = ["pandas"]
-pydot = ["pydot"]
-pygraphviz = ["pygraphviz"]
-pytest = ["pytest"]
-pyyaml = ["pyyaml"]
-scipy = ["scipy"]
-
-[[package]]
-category = "dev"
-description = "Ninja is a small build system with a focus on speed"
-marker = "python_version == \"3.7\""
-name = "ninja"
-optional = false
-python-versions = "*"
-version = "1.10.0.post2"
-
-[[package]]
category = "main"
description = "Natural Language Toolkit"
name = "nltk"
@@ -1034,7 +977,7 @@ tgrep = ["pyparsing"]
twitter = ["twython"]
[[package]]
-category = "dev"
+category = "main"
description = "A web-based notebook environment for interactive computing"
name = "notebook"
optional = false
@@ -1070,7 +1013,7 @@ python-versions = ">=3.6"
version = "1.19.4"
[[package]]
-category = "main"
+category = "dev"
description = "Python Bindings for the NVIDIA Management Library"
name = "nvidia-ml-py3"
optional = false
@@ -1110,7 +1053,7 @@ pyparsing = ">=2.0.2"
six = "*"
[[package]]
-category = "dev"
+category = "main"
description = "Utilities for writing pandoc filters in python"
name = "pandocfilters"
optional = false
@@ -1118,9 +1061,8 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
version = "1.4.3"
[[package]]
-category = "dev"
+category = "main"
description = "A Python Parser"
-marker = "python_version >= \"3.3\""
name = "parso"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
@@ -1154,7 +1096,7 @@ python-versions = ">=2.6"
version = "5.5.1"
[[package]]
-category = "dev"
+category = "main"
description = "Pexpect allows easy control of interactive console applications."
marker = "python_version >= \"3.3\" and sys_platform != \"win32\" or sys_platform != \"win32\""
name = "pexpect"
@@ -1166,7 +1108,7 @@ version = "4.8.0"
ptyprocess = ">=0.5"
[[package]]
-category = "dev"
+category = "main"
description = "Tiny 'shelve'-like database with concurrency support"
name = "pickleshare"
optional = false
@@ -1189,16 +1131,19 @@ optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
version = "0.13.1"
-[package.dependencies]
-[package.dependencies.importlib-metadata]
-python = "<3.8"
-version = ">=0.12"
-
[package.extras]
dev = ["pre-commit", "tox"]
[[package]]
-category = "dev"
+category = "main"
+description = "Python Lex & Yacc"
+name = "ply"
+optional = false
+python-versions = "*"
+version = "3.11"
+
+[[package]]
+category = "main"
description = "Python client for the Prometheus monitoring system."
name = "prometheus-client"
optional = false
@@ -1223,7 +1168,7 @@ six = "*"
test = ["pytest (>=2.7.3)", "pytest-cov", "coveralls", "futures", "pytest-benchmark", "mock"]
[[package]]
-category = "dev"
+category = "main"
description = "Library for building powerful interactive command lines in Python"
name = "prompt-toolkit"
optional = false
@@ -1235,6 +1180,17 @@ wcwidth = "*"
[[package]]
category = "main"
+description = "Protocol Buffers"
+name = "protobuf"
+optional = false
+python-versions = "*"
+version = "3.14.0"
+
+[package.dependencies]
+six = ">=1.9"
+
+[[package]]
+category = "main"
description = "Cross-platform lib for process and system monitoring in Python."
name = "psutil"
optional = false
@@ -1245,9 +1201,9 @@ version = "5.7.3"
test = ["ipaddress", "mock", "unittest2", "enum34", "pywin32", "wmi"]
[[package]]
-category = "dev"
+category = "main"
description = "Run a subprocess in a pseudo terminal"
-marker = "python_version >= \"3.3\" and sys_platform != \"win32\" or os_name != \"nt\""
+marker = "python_version >= \"3.3\" and sys_platform != \"win32\" or sys_platform != \"win32\" or os_name != \"nt\" or python_version >= \"3.3\" and sys_platform != \"win32\" and (python_version >= \"3.3\" and sys_platform != \"win32\" or sys_platform != \"win32\")"
name = "ptyprocess"
optional = false
python-versions = "*"
@@ -1270,7 +1226,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
version = "2.6.0"
[[package]]
-category = "dev"
+category = "main"
description = "C parser in Python"
name = "pycparser"
optional = false
@@ -1313,7 +1269,7 @@ python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
version = "2.4.7"
[[package]]
-category = "dev"
+category = "main"
description = "Persistent/Functional/Immutable data structures"
name = "pyrsistent"
optional = false
@@ -1338,10 +1294,6 @@ pluggy = ">=0.12,<1.0"
py = ">=1.5.0"
wcwidth = "*"
-[package.dependencies.importlib-metadata]
-python = "<3.8"
-version = ">=0.12"
-
[package.extras]
checkqa-mypy = ["mypy (v0.761)"]
testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"]
@@ -1399,14 +1351,6 @@ setuptools = "*"
[[package]]
category = "main"
-description = "PyTorch extension for fast block sparse matrices computation, drop in replacement for torch.nn.Linear."
-name = "pytorch-block-sparse"
-optional = false
-python-versions = "*"
-version = "0.1.2"
-
-[[package]]
-category = "main"
description = "The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch."
name = "pytorch-metric-learning"
optional = false
@@ -1426,23 +1370,6 @@ with-hooks = ["record-keeper (>=0.9.29)", "faiss-gpu (>=1.6.3)", "tensorboard"]
with-hooks-cpu = ["record-keeper (>=0.9.29)", "faiss-cpu (>=1.6.3)", "tensorboard"]
[[package]]
-category = "dev"
-description = "Python type inferencer"
-marker = "python_version == \"3.7\""
-name = "pytype"
-optional = false
-python-versions = "<3.9,>=3.6"
-version = "2020.11.12"
-
-[package.dependencies]
-attrs = "*"
-importlab = ">=0.5.1"
-ninja = ">=1.10.0.post2"
-pyyaml = ">=3.11"
-six = "*"
-typed_ast = "*"
-
-[[package]]
category = "main"
description = "World timezone definitions, modern and historical"
name = "pytz"
@@ -1451,7 +1378,7 @@ python-versions = "*"
version = "2020.4"
[[package]]
-category = "dev"
+category = "main"
description = "Python for Window Extensions"
marker = "sys_platform == \"win32\""
name = "pywin32"
@@ -1460,7 +1387,7 @@ python-versions = "*"
version = "300"
[[package]]
-category = "dev"
+category = "main"
description = "Python bindings for the winpty library"
marker = "os_name == \"nt\""
name = "pywinpty"
@@ -1477,7 +1404,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
version = "5.3.1"
[[package]]
-category = "dev"
+category = "main"
description = "Python bindings for 0MQ"
name = "pyzmq"
optional = false
@@ -1610,7 +1537,7 @@ version = "1.5.4"
numpy = ">=1.14.5"
[[package]]
-category = "dev"
+category = "main"
description = "Send file to trash natively under Mac OS X, Windows and Linux."
name = "send2trash"
optional = false
@@ -1619,11 +1546,19 @@ version = "1.5.0"
[[package]]
category = "main"
+description = "SentencePiece python wrapper"
+name = "sentencepiece"
+optional = false
+python-versions = "*"
+version = "0.1.95"
+
+[[package]]
+category = "main"
description = "Python client for Sentry (https://sentry.io)"
name = "sentry-sdk"
optional = false
python-versions = "*"
-version = "0.19.3"
+version = "0.19.4"
[package.dependencies]
certifi = "*"
@@ -1817,10 +1752,6 @@ version = "3.2.2"
[package.dependencies]
pbr = ">=2.0.0,<2.1.0 || >2.1.0"
-[package.dependencies.importlib-metadata]
-python = "<3.8"
-version = ">=1.7.0"
-
[[package]]
category = "main"
description = "A backport of the subprocess module from Python 3 for use on 2.x."
@@ -1830,7 +1761,7 @@ python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*, <4"
version = "3.5.4"
[[package]]
-category = "dev"
+category = "main"
description = "Tornado websocket backend for the Xterm.js Javascript terminal emulator library."
name = "terminado"
optional = false
@@ -1843,7 +1774,7 @@ pywinpty = ">=0.5"
tornado = ">=4"
[[package]]
-category = "dev"
+category = "main"
description = "Test utilities for code working with files and commands"
name = "testpath"
optional = false
@@ -1908,7 +1839,7 @@ torch = "1.7.0"
scipy = ["scipy"]
[[package]]
-category = "dev"
+category = "main"
description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed."
name = "tornado"
optional = false
@@ -1920,14 +1851,14 @@ category = "main"
description = "Fast, Extensible Progress Meter"
name = "tqdm"
optional = false
-python-versions = ">=2.6, !=3.0.*, !=3.1.*"
-version = "4.52.0"
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7"
+version = "4.53.0"
[package.extras]
dev = ["py-make (>=0.1.0)", "twine", "argopt", "pydoc-markdown", "wheel"]
[[package]]
-category = "dev"
+category = "main"
description = "Traitlets Python configuration system"
name = "traitlets"
optional = false
@@ -1941,7 +1872,7 @@ ipython-genutils = "*"
test = ["pytest"]
[[package]]
-category = "main"
+category = "dev"
description = "a fork of Python 2 and 3 ast modules with type comment support"
name = "typed-ast"
optional = false
@@ -1999,28 +1930,29 @@ description = "A CLI and library for interacting with the Weights and Biases API
name = "wandb"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
-version = "0.9.7"
+version = "0.10.12"
[package.dependencies]
Click = ">=7.0"
GitPython = ">=1.0.0"
-PyYAML = ">=3.10"
+PyYAML = "*"
configparser = ">=3.8.1"
docker-pycreds = ">=0.4.0"
-gql = "0.2.0"
-nvidia-ml-py3 = ">=7.352.0"
+promise = ">=2.0,<3"
+protobuf = ">=3.12.0"
psutil = ">=5.0.0"
python-dateutil = ">=2.6.1"
-requests = ">=2.0.0"
+requests = ">=2.0.0,<3"
sentry-sdk = ">=0.4.0"
shortuuid = ">=0.5.0"
-six = ">=1.10.0"
+six = ">=1.13.0"
subprocess32 = ">=3.5.3"
watchdog = ">=0.8.3"
[package.extras]
aws = ["boto3"]
gcp = ["google-cloud-storage"]
+grpc = ["grpcio (1.27.2)"]
kubeflow = ["kubernetes", "minio", "google-cloud-storage", "sh"]
[[package]]
@@ -2029,7 +1961,7 @@ description = "Filesystem events monitoring"
name = "watchdog"
optional = false
python-versions = "*"
-version = "0.10.3"
+version = "0.10.4"
[package.dependencies]
pathtools = ">=0.1.1"
@@ -2046,7 +1978,7 @@ python-versions = "*"
version = "0.2.5"
[[package]]
-category = "dev"
+category = "main"
description = "Character encoding aliases for legacy web content"
name = "webencodings"
optional = false
@@ -2092,23 +2024,10 @@ all = ["six", "pytest", "pytest-cov", "codecov", "scikit-build", "cmake", "ninja
optional = ["pygments", "colorama"]
tests = ["pytest", "pytest-cov", "codecov", "scikit-build", "cmake", "ninja", "pybind11"]
-[[package]]
-category = "main"
-description = "Backport of pathlib-compatible object wrapper for zip files"
-marker = "python_version < \"3.8\""
-name = "zipp"
-optional = false
-python-versions = ">=3.6"
-version = "3.4.0"
-
-[package.extras]
-docs = ["sphinx", "jaraco.packaging (>=3.2)", "rst.linker (>=1.9)"]
-testing = ["pytest (>=3.5,<3.7.3 || >3.7.3)", "pytest-checkdocs (>=1.2.3)", "pytest-flake8", "pytest-cov", "jaraco.test (>=3.2.0)", "jaraco.itertools", "func-timeout", "pytest-black (>=0.3.7)", "pytest-mypy"]
-
[metadata]
-content-hash = "9e2ebd8ad14a53756cf4d2b967e4dfb52d1c70caa0a6df6b5e42600237337631"
+content-hash = "1f194d7de179e9676ef1f8e51b83ff15c001627803008ef8225e8e14ab3acab0"
lock-version = "1.0"
-python-versions = "^3.7"
+python-versions = "^3.8"
[metadata.files]
alabaster = [
@@ -2281,8 +2200,8 @@ cycler = [
{file = "cycler-0.10.0.tar.gz", hash = "sha256:cd7b2d1018258d7247a71425e9f26463dfb444d411c39569972f4ce586b0c9d8"},
]
darglint = [
- {file = "darglint-1.5.5-py3-none-any.whl", hash = "sha256:cd882c812f28ee3b5577259bfd8d6d25962386dd87fc1f3756eac24370aaa060"},
- {file = "darglint-1.5.5.tar.gz", hash = "sha256:2f12ce2ef3d8189279a8f2eb4c53fd215dbacae50e37765542a91310400a9cd6"},
+ {file = "darglint-1.5.6-py3-none-any.whl", hash = "sha256:6fcef385e646c4da9ea6fc547e28c77a33ae0cba4806b8585ae18a490a797e82"},
+ {file = "darglint-1.5.6.tar.gz", hash = "sha256:98acb4064bae73ec02146cb123dd3c930bd5272e562ad4d19c59857443632dd1"},
]
dataclasses = [
{file = "dataclasses-0.6-py3-none-any.whl", hash = "sha256:454a69d788c7fda44efd71e259be79577822f5e3f53f029a22d08004e951dc9f"},
@@ -2297,8 +2216,8 @@ defusedxml = [
{file = "defusedxml-0.6.0.tar.gz", hash = "sha256:f684034d135af4c6cbb949b8a4d2ed61634515257a67299e5f940fbaa34377f5"},
]
desert = [
- {file = "desert-2020.1.6-py2.py3-none-any.whl", hash = "sha256:190ab1c690472ab1c1ef7614f9a73171c1e911cef42d7b45a67f9b7d6900763d"},
- {file = "desert-2020.1.6.tar.gz", hash = "sha256:e64cd61e16607bb3096ab1b1763c9f229ebcf8b3f871f4db52fb803e1c000385"},
+ {file = "desert-2020.11.18-py3-none-any.whl", hash = "sha256:6392702be7952fb9c8bbc775425fa929c1eab5c06552c2890e82a24964eeb084"},
+ {file = "desert-2020.11.18.tar.gz", hash = "sha256:d7b7fb521dc84eec955a766ed7e37349f998cf047f37fd9596cb09737d63c62d"},
]
docker-pycreds = [
{file = "docker-pycreds-0.4.0.tar.gz", hash = "sha256:6ce3270bcaf404cc4c3e27e4b6c70d3521deae82fb508767870fdbf772d584d4"},
@@ -2313,8 +2232,8 @@ dparse = [
{file = "dparse-0.5.1.tar.gz", hash = "sha256:a1b5f169102e1c894f9a7d5ccf6f9402a836a5d24be80a986c7ce9eaed78f367"},
]
einops = [
- {file = "einops-0.2.0-py2.py3-none-any.whl", hash = "sha256:96b1bac57ddb591cccb927d24934d7601c3cdf3343a79a43d316a118d66e1043"},
- {file = "einops-0.2.0.tar.gz", hash = "sha256:165ee28bcb60e5c2cbb801b5c78e181548ff8daa7c8fcabae5b251e55f7fe614"},
+ {file = "einops-0.3.0-py2.py3-none-any.whl", hash = "sha256:a91c6190ceff7d513d74ca9fd701dfa6a1ffcdd98ea0ced14350197c07f75c73"},
+ {file = "einops-0.3.0.tar.gz", hash = "sha256:a3b0935a4556f012cd5fa1851373f63366890a3f6698d117afea55fd2a40c1fc"},
]
entrypoints = [
{file = "entrypoints-0.3-py2.py3-none-any.whl", hash = "sha256:589f874b313739ad35be6e0cd7efde2a4e9b6fea91edcc34e58ecbb8dbe56d19"},
@@ -2364,11 +2283,12 @@ gitpython = [
gpustat = [
{file = "gpustat-0.6.0.tar.gz", hash = "sha256:f69135080b2668b662822633312c2180002c10111597af9631bb02e042755b6c"},
]
-gql = [
- {file = "gql-0.2.0.tar.gz", hash = "sha256:ad0f0b8226428d727c8e1d1cac4e521d83ed024d814921bd55b8adb997dadf4b"},
+graphviz = [
+ {file = "graphviz-0.16-py2.py3-none-any.whl", hash = "sha256:3cad5517c961090dfc679df6402a57de62d97703e2880a1a46147bb0dc1639eb"},
+ {file = "graphviz-0.16.zip", hash = "sha256:d2d25af1c199cad567ce4806f0449cb74eb30cf451fd7597251e1da099ac6e57"},
]
-graphql-core = [
- {file = "graphql-core-1.1.tar.gz", hash = "sha256:63bb8593aeeadb0a53e14207b910027fe51158d017927fad87326dac806185ee"},
+gtn = [
+ {file = "gtn-0.0.0.tar.gz", hash = "sha256:72fece9ca51df161c1274e570d6f5f933e76f4cac9d8d6dd543a3fe0383f7268"},
]
h5py = [
{file = "h5py-2.10.0-cp27-cp27m-macosx_10_6_intel.whl", hash = "sha256:ecf4d0b56ee394a0984de15bceeb97cbe1fe485f1ac205121293fc44dcf3f31f"},
@@ -2409,13 +2329,6 @@ imagesize = [
{file = "imagesize-1.2.0-py2.py3-none-any.whl", hash = "sha256:6965f19a6a2039c7d48bca7dba2473069ff854c36ae6f19d2cde309d998228a1"},
{file = "imagesize-1.2.0.tar.gz", hash = "sha256:b1f6b5a4eab1f73479a50fb79fcf729514a900c341d8503d62a62dbc4127a2b1"},
]
-importlab = [
- {file = "importlab-0.5.1.tar.gz", hash = "sha256:d855350d19dc10a17aabd2fe6f4b428ff1a936071f692fbf686a73694d26a51c"},
-]
-importlib-metadata = [
- {file = "importlib_metadata-2.0.0-py2.py3-none-any.whl", hash = "sha256:cefa1a2f919b866c5beb7c9f7b0ebb4061f30a8a9bf16d609b000e2dfaceb9c3"},
- {file = "importlib_metadata-2.0.0.tar.gz", hash = "sha256:77a540690e24b0305878c37ffd421785a6f7e53c8b5720d211b211de8d0e95da"},
-]
ipykernel = [
{file = "ipykernel-5.3.4-py3-none-any.whl", hash = "sha256:d6fbba26dba3cebd411382bc484f7bc2caa98427ae0ddb4ab37fe8bfeb5c7dd3"},
{file = "ipykernel-5.3.4.tar.gz", hash = "sha256:9b2652af1607986a1b231c62302d070bc0534f564c393a5d9d130db9abbbe89d"},
@@ -2462,13 +2375,17 @@ jupyter-console = [
{file = "jupyter_console-6.2.0.tar.gz", hash = "sha256:7f6194f4f4692d292da3f501c7f343ccd5e36c6a1becf7b7515e23e66d6bf1e9"},
]
jupyter-core = [
- {file = "jupyter_core-4.6.3-py2.py3-none-any.whl", hash = "sha256:a4ee613c060fe5697d913416fc9d553599c05e4492d58fac1192c9a6844abb21"},
- {file = "jupyter_core-4.6.3.tar.gz", hash = "sha256:394fd5dd787e7c8861741880bdf8a00ce39f95de5d18e579c74b882522219e7e"},
+ {file = "jupyter_core-4.7.0-py3-none-any.whl", hash = "sha256:0a451c9b295e4db772bdd8d06f2f1eb31caeec0e81fbb77ba37d4a3024e3b315"},
+ {file = "jupyter_core-4.7.0.tar.gz", hash = "sha256:aa1f9496ab3abe72da4efe0daab0cb2233997914581f9a071e07498c6add8ed3"},
]
jupyterlab-pygments = [
{file = "jupyterlab_pygments-0.1.2-py2.py3-none-any.whl", hash = "sha256:abfb880fd1561987efaefcb2d2ac75145d2a5d0139b1876d5be806e32f630008"},
{file = "jupyterlab_pygments-0.1.2.tar.gz", hash = "sha256:cfcda0873626150932f438eccf0f8bf22bfa92345b814890ab360d666b254146"},
]
+jupyterthemes = [
+ {file = "jupyterthemes-0.20.0-py2.py3-none-any.whl", hash = "sha256:4bd42fc88a06e3afabbe70c2ee25e6467147512993a3cbd9bec57ae3fd2e2fb1"},
+ {file = "jupyterthemes-0.20.0.tar.gz", hash = "sha256:2a8ebc0c84b212ab99b9f1757fc0582a3f53930d3a75b2492d91a7c8b36ab41e"},
+]
kiwisolver = [
{file = "kiwisolver-1.3.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:fd34fbbfbc40628200730bc1febe30631347103fc8d3d4fa012c21ab9c11eca9"},
{file = "kiwisolver-1.3.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:d3155d828dec1d43283bd24d3d3e0d9c7c350cdfcc0bd06c0ad1209c1bbc36d0"},
@@ -2503,6 +2420,10 @@ kiwisolver = [
{file = "kiwisolver-1.3.1-pp36-pypy36_pp73-win32.whl", hash = "sha256:401a2e9afa8588589775fe34fc22d918ae839aaaf0c0e96441c0fdbce6d8ebe6"},
{file = "kiwisolver-1.3.1.tar.gz", hash = "sha256:950a199911a8d94683a6b10321f9345d5a3a8433ec58b217ace979e18f16e248"},
]
+lesscpy = [
+ {file = "lesscpy-0.14.0-py2.py3-none-any.whl", hash = "sha256:b0f2f853ee1dfb0891b147b57028057d5389510e079581e7b533d07dc0d95d3e"},
+ {file = "lesscpy-0.14.0.tar.gz", hash = "sha256:7b664f60818a16afa8cc9f1dd6d9b17f944e0ce94e50787d76f81bc7a8648cce"},
+]
loguru = [
{file = "loguru-0.5.3-py3-none-any.whl", hash = "sha256:f8087ac396b5ee5f67c963b495d615ebbceac2796379599820e324419d53667c"},
{file = "loguru-0.5.3.tar.gz", hash = "sha256:b28e72ac7a98be3d28ad28570299a393dfcd32e5e3f6a353dec94675767b6319"},
@@ -2621,23 +2542,6 @@ nest-asyncio = [
{file = "nest_asyncio-1.4.3-py3-none-any.whl", hash = "sha256:dbe032f3e9ff7f120e76be22bf6e7958e867aed1743e6894b8a9585fe8495cc9"},
{file = "nest_asyncio-1.4.3.tar.gz", hash = "sha256:eaa09ef1353ebefae19162ad423eef7a12166bcc63866f8bff8f3635353cd9fa"},
]
-networkx = [
- {file = "networkx-2.5-py3-none-any.whl", hash = "sha256:8c5812e9f798d37c50570d15c4a69d5710a18d77bafc903ee9c5fba7454c616c"},
- {file = "networkx-2.5.tar.gz", hash = "sha256:7978955423fbc9639c10498878be59caf99b44dc304c2286162fd24b458c1602"},
-]
-ninja = [
- {file = "ninja-1.10.0.post2-py2-none-macosx_10_6_x86_64.whl", hash = "sha256:a1a9d9455623a3f45557fff6eb5abb3e70910dde28cfb9239e3ca14249149f55"},
- {file = "ninja-1.10.0.post2-py2-none-manylinux1_i686.whl", hash = "sha256:99c6102ae9a8981afe4d06f92508dbeab1e28ec89783fb703411166f4e13c9ee"},
- {file = "ninja-1.10.0.post2-py2-none-manylinux1_x86_64.whl", hash = "sha256:4252ce532304841e47478bb61710fcf9940cf2c91731303490762b6e4f23fd2b"},
- {file = "ninja-1.10.0.post2-py2-none-win32.whl", hash = "sha256:24acc95359308d11243386cf9f076bdc95f438ef6a4e0e357e7c122c5e02816d"},
- {file = "ninja-1.10.0.post2-py2-none-win_amd64.whl", hash = "sha256:16fc1bea52a36a91a0e80c3b221d2c1bc9bcf04d0564da9344e349b8c5efd5c6"},
- {file = "ninja-1.10.0.post2-py3-none-macosx_10_6_x86_64.whl", hash = "sha256:1d9ed3b5fdeb646516f54bec92453dcb3000d6771c2fea56451444c988a23e29"},
- {file = "ninja-1.10.0.post2-py3-none-manylinux1_i686.whl", hash = "sha256:5c3a8cb54aaaf5d4f692d65121ef47b3e43dea123a6563153d9d97631c0adf4f"},
- {file = "ninja-1.10.0.post2-py3-none-manylinux1_x86_64.whl", hash = "sha256:fb1ae96811a9b73773014b8a21d710b89d7d5f765427a5e2541e7fb9d530fdd5"},
- {file = "ninja-1.10.0.post2-py3-none-win32.whl", hash = "sha256:06a72090f5c5516e57f12699644179504a77585bed6d5f8be9e67219a398ec80"},
- {file = "ninja-1.10.0.post2-py3-none-win_amd64.whl", hash = "sha256:c6059bd04ad235e2326b39bc71bb7989de8d565084b5f269557704747b2910fa"},
- {file = "ninja-1.10.0.post2.tar.gz", hash = "sha256:621fd73513a9bef0cb82e8c531a29ef96580b4d6e797f833cce167054ad812f8"},
-]
nltk = [
{file = "nltk-3.5.zip", hash = "sha256:845365449cd8c5f9731f7cb9f8bd6fd0767553b9d53af9eb1b3abf7700936b35"},
]
@@ -2775,6 +2679,10 @@ pluggy = [
{file = "pluggy-0.13.1-py2.py3-none-any.whl", hash = "sha256:966c145cd83c96502c3c3868f50408687b38434af77734af1e9ca461a4081d2d"},
{file = "pluggy-0.13.1.tar.gz", hash = "sha256:15b2acde666561e1298d71b523007ed7364de07029219b604cf808bfa1c765b0"},
]
+ply = [
+ {file = "ply-3.11-py2.py3-none-any.whl", hash = "sha256:096f9b8350b65ebd2fd1346b12452efe5b9607f7482813ffca50c22722a807ce"},
+ {file = "ply-3.11.tar.gz", hash = "sha256:00c7c1aaa88358b9c765b6d3000c6eec0ba42abca5351b095321aef446081da3"},
+]
prometheus-client = [
{file = "prometheus_client-0.9.0-py2.py3-none-any.whl", hash = "sha256:b08c34c328e1bf5961f0b4352668e6c8f145b4a087e09b7296ef62cbe4693d35"},
{file = "prometheus_client-0.9.0.tar.gz", hash = "sha256:9da7b32f02439d8c04f7777021c304ed51d9ec180604700c1ba72a4d44dceb03"},
@@ -2786,6 +2694,26 @@ prompt-toolkit = [
{file = "prompt_toolkit-3.0.8-py3-none-any.whl", hash = "sha256:7debb9a521e0b1ee7d2fe96ee4bd60ef03c6492784de0547337ca4433e46aa63"},
{file = "prompt_toolkit-3.0.8.tar.gz", hash = "sha256:25c95d2ac813909f813c93fde734b6e44406d1477a9faef7c915ff37d39c0a8c"},
]
+protobuf = [
+ {file = "protobuf-3.14.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:629b03fd3caae7f815b0c66b41273f6b1900a579e2ccb41ef4493a4f5fb84f3a"},
+ {file = "protobuf-3.14.0-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:5b7a637212cc9b2bcf85dd828b1178d19efdf74dbfe1ddf8cd1b8e01fdaaa7f5"},
+ {file = "protobuf-3.14.0-cp35-cp35m-macosx_10_9_intel.whl", hash = "sha256:43b554b9e73a07ba84ed6cf25db0ff88b1e06be610b37656e292e3cbb5437472"},
+ {file = "protobuf-3.14.0-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:5e9806a43232a1fa0c9cf5da8dc06f6910d53e4390be1fa06f06454d888a9142"},
+ {file = "protobuf-3.14.0-cp35-cp35m-win32.whl", hash = "sha256:1c51fda1bbc9634246e7be6016d860be01747354ed7015ebe38acf4452f470d2"},
+ {file = "protobuf-3.14.0-cp35-cp35m-win_amd64.whl", hash = "sha256:4b74301b30513b1a7494d3055d95c714b560fbb630d8fb9956b6f27992c9f980"},
+ {file = "protobuf-3.14.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:86a75477addde4918e9a1904e5c6af8d7b691f2a3f65587d73b16100fbe4c3b2"},
+ {file = "protobuf-3.14.0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:ecc33531a213eee22ad60e0e2aaea6c8ba0021f0cce35dbf0ab03dee6e2a23a1"},
+ {file = "protobuf-3.14.0-cp36-cp36m-win32.whl", hash = "sha256:72230ed56f026dd664c21d73c5db73ebba50d924d7ba6b7c0d81a121e390406e"},
+ {file = "protobuf-3.14.0-cp36-cp36m-win_amd64.whl", hash = "sha256:0fc96785262042e4863b3f3b5c429d4636f10d90061e1840fce1baaf59b1a836"},
+ {file = "protobuf-3.14.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:4e75105c9dfe13719b7293f75bd53033108f4ba03d44e71db0ec2a0e8401eafd"},
+ {file = "protobuf-3.14.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:2a7e2fe101a7ace75e9327b9c946d247749e564a267b0515cf41dfe450b69bac"},
+ {file = "protobuf-3.14.0-cp37-cp37m-win32.whl", hash = "sha256:b0d5d35faeb07e22a1ddf8dce620860c8fe145426c02d1a0ae2688c6e8ede36d"},
+ {file = "protobuf-3.14.0-cp37-cp37m-win_amd64.whl", hash = "sha256:8971c421dbd7aad930c9bd2694122f332350b6ccb5202a8b7b06f3f1a5c41ed5"},
+ {file = "protobuf-3.14.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9616f0b65a30851e62f1713336c931fcd32c057202b7ff2cfbfca0fc7d5e3043"},
+ {file = "protobuf-3.14.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:22bcd2e284b3b1d969c12e84dc9b9a71701ec82d8ce975fdda19712e1cfd4e00"},
+ {file = "protobuf-3.14.0-py2.py3-none-any.whl", hash = "sha256:0e247612fadda953047f53301a7b0407cb0c3cb4ae25a6fde661597a04039b3c"},
+ {file = "protobuf-3.14.0.tar.gz", hash = "sha256:1d63eb389347293d8915fb47bee0951c7b5dab522a4a60118b9a18f33e21f8ce"},
+]
psutil = [
{file = "psutil-5.7.3-cp27-none-win32.whl", hash = "sha256:1cd6a0c9fb35ece2ccf2d1dd733c1e165b342604c67454fd56a4c12e0a106787"},
{file = "psutil-5.7.3-cp27-none-win_amd64.whl", hash = "sha256:e02c31b2990dcd2431f4524b93491941df39f99619b0d312dfe1d4d530b08b4b"},
@@ -2853,22 +2781,10 @@ python-dateutil = [
python-levenshtein = [
{file = "python-Levenshtein-0.12.0.tar.gz", hash = "sha256:033a11de5e3d19ea25c9302d11224e1a1898fe5abd23c61c7c360c25195e3eb1"},
]
-pytorch-block-sparse = [
- {file = "pytorch_block_sparse-0.1.2.tar.gz", hash = "sha256:ca4a5c1dde96ac01c007f209067b2bbaee311a8699eba1eef712faef7f97df1f"},
-]
pytorch-metric-learning = [
{file = "pytorch-metric-learning-0.9.94.tar.gz", hash = "sha256:523ab08ee10745edc6512cc32b62b4ba0c858906cfd5a2e9e5c9bfa1a6b7daa2"},
{file = "pytorch_metric_learning-0.9.94-py3-none-any.whl", hash = "sha256:3719c380c3b8d90f599c3c7e9fe7410d025b091d389ef7769044a1437096dbcc"},
]
-pytype = [
- {file = "pytype-2020.11.12-cp36-cp36m-macosx_10_14_x86_64.whl", hash = "sha256:ea77133d694584caadd7d5e1769797b3a65f5759a18a1ccac6770bed37221b83"},
- {file = "pytype-2020.11.12-cp36-cp36m-manylinux2014_x86_64.whl", hash = "sha256:a711a477e45a13737623a89c87cb43d8edca35c3e27dfb0a511bbd1a6e76da30"},
- {file = "pytype-2020.11.12-cp37-cp37m-macosx_10_14_x86_64.whl", hash = "sha256:79933e2229e9b8f8f0de6ab2625a33ed38dd6d74598e49ab74f178fe7c53e0de"},
- {file = "pytype-2020.11.12-cp37-cp37m-manylinux2014_x86_64.whl", hash = "sha256:10acd00450af59abd13b4dee7f1d6ba275f7e4b969e339e4ee611c886c1b9fed"},
- {file = "pytype-2020.11.12-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:0b3f38d1a4e10db09227905c4d8dfecd4d0923d8dc423c9ba6d95ea30f156964"},
- {file = "pytype-2020.11.12-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:3d1c2b684232e6e9e9a5f5e62080cc12f2fbd87a693a16eb67ad2c95d8671a05"},
- {file = "pytype-2020.11.12.tar.gz", hash = "sha256:66f694abee3eea5a1c7f8ca040324f8b13b1ed01e308ba898bf36775b0a4c944"},
-]
pytz = [
{file = "pytz-2020.4-py2.py3-none-any.whl", hash = "sha256:5c55e189b682d420be27c6995ba6edce0c0a77dd67bfbe2ae6607134d5851ffd"},
{file = "pytz-2020.4.tar.gz", hash = "sha256:3e6b7dd2d1e0a59084bcee14a17af60c5c562cdc16d828e8eba2e683d3a7e268"},
@@ -3054,9 +2970,50 @@ send2trash = [
{file = "Send2Trash-1.5.0-py3-none-any.whl", hash = "sha256:f1691922577b6fa12821234aeb57599d887c4900b9ca537948d2dac34aea888b"},
{file = "Send2Trash-1.5.0.tar.gz", hash = "sha256:60001cc07d707fe247c94f74ca6ac0d3255aabcb930529690897ca2a39db28b2"},
]
+sentencepiece = [
+ {file = "sentencepiece-0.1.95-cp35-cp35m-macosx_10_6_x86_64.whl", hash = "sha256:21cfec2ec80eb6f603fb92b0416479272f3ec30cfd511b8525a964e2f1cf82a6"},
+ {file = "sentencepiece-0.1.95-cp35-cp35m-manylinux2014_aarch64.whl", hash = "sha256:f05663139279718421084d618131a24cffc068860873531ebfe38a73085cbd2e"},
+ {file = "sentencepiece-0.1.95-cp35-cp35m-manylinux2014_i686.whl", hash = "sha256:43acdb01466de8189b899de153b96eb50e0ea3b77608c1d4f4f8f0c6f343fe45"},
+ {file = "sentencepiece-0.1.95-cp35-cp35m-manylinux2014_ppc64le.whl", hash = "sha256:243ce7c067ba15e5883ab772117b144a8fa1f5827c466a664c9f52d173f6e375"},
+ {file = "sentencepiece-0.1.95-cp35-cp35m-manylinux2014_s390x.whl", hash = "sha256:8613286b537056e6d2029e306719e33d4e09c369a1741490e4e18f2a6a797996"},
+ {file = "sentencepiece-0.1.95-cp35-cp35m-manylinux2014_x86_64.whl", hash = "sha256:c510e0d26760d51b31f2fb05e1638419a1590df8783300d79e898f2bb93975a8"},
+ {file = "sentencepiece-0.1.95-cp36-cp36m-macosx_10_6_x86_64.whl", hash = "sha256:94f866601203b78095d9f219995820ff4606d67281895a6c79d5c1ffe75575ac"},
+ {file = "sentencepiece-0.1.95-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:d789bcdce025b377a45830d3962d041b1acf7e416e5451bef081bd6a9c758dfd"},
+ {file = "sentencepiece-0.1.95-cp36-cp36m-manylinux2014_i686.whl", hash = "sha256:53951098eddfc25a5fa0cd9be748c9346db3c2be0b6c74a8ac6663acbde2b639"},
+ {file = "sentencepiece-0.1.95-cp36-cp36m-manylinux2014_ppc64le.whl", hash = "sha256:f5b6ab735d30eb1801998d4c413592149f9414d9aa300d90a28e8769792d2a5b"},
+ {file = "sentencepiece-0.1.95-cp36-cp36m-manylinux2014_s390x.whl", hash = "sha256:60715ef703af2410e5f5cac89d8123f1a0a8dbce1406a2ceaecf805eb0c0cfd9"},
+ {file = "sentencepiece-0.1.95-cp36-cp36m-manylinux2014_x86_64.whl", hash = "sha256:d880e8f70822fe98b4f584814f5cccebf9e72aea7b44acc1a26731780fac03f7"},
+ {file = "sentencepiece-0.1.95-cp36-cp36m-win32.whl", hash = "sha256:d89c04aeedab0d5c25de8fc6302d58ec6fb135e2670449376c7d0301d7963680"},
+ {file = "sentencepiece-0.1.95-cp36-cp36m-win_amd64.whl", hash = "sha256:8e2f6096899a32246a0c65ea7f24a01ff32ea49563ef013b348acb7bca5831d5"},
+ {file = "sentencepiece-0.1.95-cp37-cp37m-macosx_10_6_x86_64.whl", hash = "sha256:438ee23faf095a9ebcc97debad2b07c0647ff6a306ed4d430146c3f80c7f6354"},
+ {file = "sentencepiece-0.1.95-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:48fd95e0bf082a432cff5d4b7e5fa6d5fdaf87fb2de210aa91f90086c89464a2"},
+ {file = "sentencepiece-0.1.95-cp37-cp37m-manylinux2014_i686.whl", hash = "sha256:163d869ce8dd7a9ed11187756272e8c73cd1caae1f47a701e5d70ad80485a655"},
+ {file = "sentencepiece-0.1.95-cp37-cp37m-manylinux2014_ppc64le.whl", hash = "sha256:087373b148b82854a3c03a9ad57d58a8ff5366b2f6d718bca27f262c102439ce"},
+ {file = "sentencepiece-0.1.95-cp37-cp37m-manylinux2014_s390x.whl", hash = "sha256:fa8ee7411f31a7e7e1b4ed48de958e63befdba3465d7c7d9bd5a87235f7e5bd1"},
+ {file = "sentencepiece-0.1.95-cp37-cp37m-manylinux2014_x86_64.whl", hash = "sha256:ad2866aebdf702b0d6a992b2b3b46c2de3739ca8a92bce17f24cf51c29fa4f3e"},
+ {file = "sentencepiece-0.1.95-cp37-cp37m-win32.whl", hash = "sha256:7f7929c7741ea276d44c1e7966a1347943fab2089a55bc32fc42ba3c71a6e2e1"},
+ {file = "sentencepiece-0.1.95-cp37-cp37m-win_amd64.whl", hash = "sha256:c2add7d87c30898661de5b9e492bd99c5b184c731dec3c7dd3d2c956e4003446"},
+ {file = "sentencepiece-0.1.95-cp38-cp38-macosx_10_6_x86_64.whl", hash = "sha256:453f9cf531b5ea694472a5f0a4dc727bfb4f383c8a80a9b5261db6d3a59d4018"},
+ {file = "sentencepiece-0.1.95-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:5ff761f322a1b34d691d8b1d87c735d8de725ce3458d879d9d0c319e285e7169"},
+ {file = "sentencepiece-0.1.95-cp38-cp38-manylinux2014_i686.whl", hash = "sha256:bc7324da0209b632be107123f40505e2400e6aa49e39b49a35d081c36e6cee1b"},
+ {file = "sentencepiece-0.1.95-cp38-cp38-manylinux2014_ppc64le.whl", hash = "sha256:5e177f6e40b074e08d3c0c2a1a862fbc94897d9c3439c7752a03a4f61197a743"},
+ {file = "sentencepiece-0.1.95-cp38-cp38-manylinux2014_s390x.whl", hash = "sha256:5cac1dcacc2c6bea397188daa549f194ca2bc4d0a7005633ecd03b165e1ad16f"},
+ {file = "sentencepiece-0.1.95-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:77dff55aa8f74e36f7fd7df861723574630327fdfff0ca18fdbb4fe031c9ecbe"},
+ {file = "sentencepiece-0.1.95-cp38-cp38-win32.whl", hash = "sha256:e3cf28e56f49edb9ac021e247399671b8099e516ecd8091ee8ad5d35716e16e3"},
+ {file = "sentencepiece-0.1.95-cp38-cp38-win_amd64.whl", hash = "sha256:6365bb9b7a17573e1ed9a277eafad6b5a489100840149297b2f399294ca11817"},
+ {file = "sentencepiece-0.1.95-cp39-cp39-macosx_10_6_x86_64.whl", hash = "sha256:58a1013c2a676e16647c64505b9e8cd7e7e5fb9f2d92ec91f2d2a5f777632a69"},
+ {file = "sentencepiece-0.1.95-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:10e8119175e35075d05dad49c2903903229c7b1331b872fff5ad6a85d369152c"},
+ {file = "sentencepiece-0.1.95-cp39-cp39-manylinux2014_i686.whl", hash = "sha256:99ba407001cc45b76e56e03f63eb27e011fe614c3a38e2c0ed5818bb88e050f6"},
+ {file = "sentencepiece-0.1.95-cp39-cp39-manylinux2014_ppc64le.whl", hash = "sha256:fa52e8a438f500e07c81c068fe128f9c4e677331eff0b17b28c55585aa7c112a"},
+ {file = "sentencepiece-0.1.95-cp39-cp39-manylinux2014_s390x.whl", hash = "sha256:26676ecc4985902cf4af5d597df3d2c4f32f58ed3e23db20c47950f6065089d7"},
+ {file = "sentencepiece-0.1.95-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:03ac268fa1f5f2adcb083f40becd63b5bbbe2c13dec2cd46222688f8477827c5"},
+ {file = "sentencepiece-0.1.95-cp39-cp39-win32.whl", hash = "sha256:4749b187c91e796fe52b82abef3c05a60d82065088844c0fe45d5c221ddc097a"},
+ {file = "sentencepiece-0.1.95-cp39-cp39-win_amd64.whl", hash = "sha256:d3410cffb275c319c61977ae3a8729ab224d330bdf69d66cf5f6c55a4deb3d9a"},
+ {file = "sentencepiece-0.1.95.tar.gz", hash = "sha256:8dd6e3e110f4c3f85a46e4a2ae1b6f7cf020b907fab50eac22beccf1680e0ea5"},
+]
sentry-sdk = [
- {file = "sentry-sdk-0.19.3.tar.gz", hash = "sha256:fd48f627945511c140546939b4d73815be4860cd1d2b9149577d7f6563e7bd60"},
- {file = "sentry_sdk-0.19.3-py2.py3-none-any.whl", hash = "sha256:81d7a5d8ca0b13a16666e8280127b004565aa988bfeec6481e98a8601804b215"},
+ {file = "sentry-sdk-0.19.4.tar.gz", hash = "sha256:1052f0ed084e532f66cb3e4ba617960d820152aee8b93fc6c05bd53861768c1c"},
+ {file = "sentry_sdk-0.19.4-py2.py3-none-any.whl", hash = "sha256:4c42910a55a6b1fe694d5e4790d5188d105d77b5a6346c1c64cbea8c06c0e8b7"},
]
shortuuid = [
{file = "shortuuid-1.0.1-py3-none-any.whl", hash = "sha256:492c7402ff91beb1342a5898bd61ea953985bf24a41cd9f247409aa2e03c8f77"},
@@ -3198,8 +3155,8 @@ tornado = [
{file = "tornado-6.1.tar.gz", hash = "sha256:33c6e81d7bd55b468d2e793517c909b139960b6c790a60b7991b9b6b76fb9791"},
]
tqdm = [
- {file = "tqdm-4.52.0-py2.py3-none-any.whl", hash = "sha256:80d9d5165d678dbd027dd102dfb99f71bf05f333b61fb761dbba13b4ab719ead"},
- {file = "tqdm-4.52.0.tar.gz", hash = "sha256:18d6a615aedd09ec8456d9524489dab330af4bd5c2a14a76eb3f9a0e14471afe"},
+ {file = "tqdm-4.53.0-py2.py3-none-any.whl", hash = "sha256:5ff3f5232b19fa4c5531641e480b7fad4598819f708a32eb815e6ea41c5fa313"},
+ {file = "tqdm-4.53.0.tar.gz", hash = "sha256:3d3f1470d26642e88bd3f73353cb6ff4c51ef7d5d7efef763238f4bc1f7e4e81"},
]
traitlets = [
{file = "traitlets-5.0.5-py3-none-any.whl", hash = "sha256:69ff3f9d5351f31a7ad80443c2674b7099df13cc41fc5fa6e2f6d3b0330b0426"},
@@ -3247,11 +3204,11 @@ urllib3 = [
{file = "urllib3-1.26.2.tar.gz", hash = "sha256:19188f96923873c92ccb987120ec4acaa12f0461fa9ce5d3d0772bc965a39e08"},
]
wandb = [
- {file = "wandb-0.9.7-py2.py3-none-any.whl", hash = "sha256:21d6f17c868c5de6b400c878962c1933f0574f1088f981b99f393cfeb80410b0"},
- {file = "wandb-0.9.7.tar.gz", hash = "sha256:b07a4cc7c317528273bd10ba903fd3fe851cab995d4ddaa7491b55e292f1c87d"},
+ {file = "wandb-0.10.12-py2.py3-none-any.whl", hash = "sha256:b3bf35840fd4048d85730e698f10b0fafb7bae05025ee2243793f30300e3f3d8"},
+ {file = "wandb-0.10.12.tar.gz", hash = "sha256:052dd5f59ab1a655a82253bc4603678fe06e0136be097dc1964f1a0c5bd64116"},
]
watchdog = [
- {file = "watchdog-0.10.3.tar.gz", hash = "sha256:4214e1379d128b0588021880ccaf40317ee156d4603ac388b9adcf29165e0c04"},
+ {file = "watchdog-0.10.4.tar.gz", hash = "sha256:e38bffc89b15bafe2a131f0e1c74924cf07dcec020c2e0a26cccd208831fcd43"},
]
wcwidth = [
{file = "wcwidth-0.2.5-py2.py3-none-any.whl", hash = "sha256:beb4802a9cebb9144e99086eff703a642a13d6a0052920003a230f3294bbe784"},
@@ -3273,7 +3230,3 @@ xdoctest = [
{file = "xdoctest-0.12.0-py2.py3-none-any.whl", hash = "sha256:82424d2cc4b6d6b96b7b7134c81e97a4594c536547c1954533128a6a26cf1cb2"},
{file = "xdoctest-0.12.0.tar.gz", hash = "sha256:2d985d8d78d4444079d3b072965327ab06a5e6dcb4882f3561d7596eb4da6b13"},
]
-zipp = [
- {file = "zipp-3.4.0-py3-none-any.whl", hash = "sha256:102c24ef8f171fd729d46599845e95c7ab894a4cf45f5de11a44cc7444fb1108"},
- {file = "zipp-3.4.0.tar.gz", hash = "sha256:ed5eee1974372595f9e416cc7bbeeb12335201d8081ca8a0743c954d4446e5cb"},
-]
diff --git a/pyproject.toml b/pyproject.toml
index c977270..4c674bc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,7 +10,7 @@ repository = "https://github.com/aktersnurra/text-recognizer"
keywords = ["text recognizer, deep learning, pytorch"]
[tool.poetry.dependencies]
-python = "^3.7"
+python = "^3.8"
click = "^7.1.2"
flake8-annotations = "^2.1.0"
flake8-docstrings = "^1.5.0"
@@ -30,14 +30,16 @@ tqdm = "^4.46.1"
pytest = "^5.4.3"
opencv-python = "^4.3.0"
nltk = "^3.5"
-einops = "^0.2.0"
-wandb = "^0.9.6"
torch-summary = "^1.4.2"
python-Levenshtein = "^0.12.0"
defusedxml = "^0.6.0"
-pytorch-block-sparse = "^0.1.2"
pytorch-metric-learning = "^0.9.92"
omegaconf = "^2.0.2"
+jupyterthemes = "^0.20.0"
+wandb = "^0.10.12"
+einops = "^0.3.0"
+gtn = "^0.0.0"
+sentencepiece = "^0.1.95"
[tool.poetry.dev-dependencies]
pytest = "^5.4.2"
@@ -59,7 +61,8 @@ sphinx = "^3.0.4"
jupyter = "^1.0.0"
gpustat = "^0.6.0"
redlock-py = "^1.0.8"
-wandb = "^0.9.4"
+wandb = "^0.10.11"
+graphviz = "^0.16"
[tool.coverage.report]
fail_under = 50
@@ -71,6 +74,7 @@ create-emnist-support-files = "text_recognizer.tests.support.create_emnist_suppo
create-emnist-lines-datasets = "text_recognizer.datasets.emnist_lines_dataset:create_datasets"
create-iam-paragraphs = "text_recognizer.datasets.iam_paragraphs_dataset:main"
prepare-experiments = "training.prepare_experiments:run_cli"
+run-experiment = "training.run_experiment:run_cli"
diff --git a/src/notebooks/00-testing-stuff-out.ipynb b/src/notebooks/00-testing-stuff-out.ipynb
index b5fdbe0..0e4b298 100644
--- a/src/notebooks/00-testing-stuff-out.ipynb
+++ b/src/notebooks/00-testing-stuff-out.ipynb
@@ -16,6 +16,7 @@
"import torch.nn.functional as F\n",
"import torch\n",
"from torch import nn\n",
+ "from torchsummary import summary\n",
"from importlib.util import find_spec\n",
"if find_spec(\"text_recognizer\") is None:\n",
" import sys\n",
@@ -24,73 +25,76 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
- "from text_recognizer.networks import CTCTransformer"
+ "from text_recognizer.networks import CNN"
]
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 63,
"metadata": {},
"outputs": [],
"source": [
- "model = CTCTransformer(\n",
- " num_encoder_layers=2,\n",
- " hidden_dim=256,\n",
- " vocab_size=56,\n",
- " num_heads=8,\n",
- " adaptive_pool_dim=[None, 1],\n",
- " expansion_dim=2048,\n",
- " dropout_rate=0.1,\n",
- " max_len=256,\n",
- " patch_size=(28, 32),\n",
- " stride=(1, 28),\n",
- " activation=\"gelu\",\n",
- " backbone=\"WideResidualNetwork\",\n",
- "backbone_args={\n",
- " \"in_channels\": 1,\n",
- " \"in_planes\": 64,\n",
- " \"num_classes\": 80,\n",
- " \"depth\": 10,\n",
- " \"width_factor\": 1,\n",
- " \"dropout_rate\": 0.1,\n",
- " \"num_layers\": 4,\n",
- " \"num_stages\": [64, 128, 256, 256],\n",
- " \"activation\": \"elu\",\n",
- " \"use_decoder\": False,\n",
- "},\n",
- " )"
+ "cnn = CNN()"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 79,
"metadata": {},
"outputs": [],
- "source": []
+ "source": [
+ "i = nn.Sequential(nn.Conv2d(1,1,1,1))"
+ ]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 81,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Sequential(\n",
+ " (0): Sequential(\n",
+ " (0): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (1): Sequential(\n",
+ " (0): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "execution_count": 81,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "nn.Sequential(i,i)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 64,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([2, 128, 1, 59])"
+ ]
+ },
+ "execution_count": 64,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
- "backbone: WideResidualNetwork\n",
- " backbone_args:\n",
- " in_channels: 1\n",
- " in_planes: 64\n",
- " num_classes: 80\n",
- " depth: 10\n",
- " width_factor: 1\n",
- " dropout_rate: 0.1\n",
- " num_layers: 4 \n",
- " num_stages: [64, 128, 256, 256]\n",
- " activation: elu\n",
- " use_decoder: false\n",
- " n"
+ "cnn(t).shape"
]
},
{
@@ -99,80 +103,236 @@
"metadata": {},
"outputs": [],
"source": [
+ "from text_recognizer.networks.vqvae import Encoder, Decoder, VQVAE"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "vqvae = VQVAE(1, [32, 128, 128, 256], [4, 4, 4, 4], [2, 2, [1, 2], [1, 2]], 2, 32, 256, [[6, 119], [7, 238]])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"t = torch.randn(2, 1, 28, 952)"
]
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x, l = vqvae(t)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "torch.Size([56, 952])"
+ "29.5"
]
},
- "execution_count": 3,
+ "execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "t.view(-1, 952).shape"
+ "5 * 59 / 10"
]
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "torch.Size([119, 2, 56])"
+ "torch.Size([2, 1, 28, 952])"
]
},
- "execution_count": 14,
+ "execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "model(t).shape"
+ "x.shape"
]
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 26,
"metadata": {},
"outputs": [
{
- "ename": "RuntimeError",
- "evalue": "Failed to run torchsummary. See above stack traces for more details. Executed layers up to: [WideResidualNetwork: 1-1, Sequential: 2-1, Conv2d: 3-1, Sequential: 3-2, WideBlock: 4-1, Sequential: 3-3, WideBlock: 4-2, Sequential: 3-4, WideBlock: 4-3, Sequential: 3-5, WideBlock: 4-4, AdaptiveAvgPool2d: 1-2, Encoder: 1-3, EncoderLayer: 3-6, MultiHeadAttention: 4-5, _IntraLayerConnection: 4-6, _ConvolutionalLayer: 4-7, _IntraLayerConnection: 4-8, EncoderLayer: 3-7, MultiHeadAttention: 4-9, _IntraLayerConnection: 4-10, _ConvolutionalLayer: 4-11, _IntraLayerConnection: 4-12, LayerNorm: 2-2, Linear: 2-3, GLU: 2-4]",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m----------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torchsummary/torchsummary.py\u001b[0m in \u001b[0;36msummary\u001b[0;34m(model, input_data, batch_dim, branching, col_names, col_width, depth, device, dtypes, verbose, *args, **kwargs)\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 124\u001b[0;31m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 125\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n",
- "\u001b[0;32m~/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/text_recognizer/networks/ctc_transformer.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, trg)\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[0mcontext\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcontext_representation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimage_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 110\u001b[0;31m \u001b[0mlogits\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhead\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcontext\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 111\u001b[0m \u001b[0mlogits\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrearrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlogits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"b t y -> t b y\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n",
- "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torch/nn/modules/container.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 117\u001b[0;31m \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 118\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n",
- "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torch/nn/modules/linear.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 93\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 94\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mlinear\u001b[0;34m(input, weight, bias)\u001b[0m\n\u001b[1;32m 1691\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1692\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1693\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mbias\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;31mRuntimeError\u001b[0m: mat1 and mat2 shapes cannot be multiplied (238x128 and 256x56)",
- "\nThe above exception was the direct cause of the following exception:\n",
- "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m<ipython-input-8-85c5209ae40a>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0msummary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m28\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m952\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"cpu\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdepth\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
- "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torchsummary/torchsummary.py\u001b[0m in \u001b[0;36msummary\u001b[0;34m(model, input_data, batch_dim, branching, col_names, col_width, depth, device, dtypes, verbose, *args, **kwargs)\u001b[0m\n\u001b[1;32m 125\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 126\u001b[0m \u001b[0mexecuted_layers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mlayer\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mlayer\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msummary_list\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlayer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexecuted\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 127\u001b[0;31m raise RuntimeError(\n\u001b[0m\u001b[1;32m 128\u001b[0m \u001b[0;34m\"Failed to run torchsummary. See above stack traces for more details. \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 129\u001b[0m \u001b[0;34m\"Executed layers up to: {}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexecuted_layers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;31mRuntimeError\u001b[0m: Failed to run torchsummary. See above stack traces for more details. Executed layers up to: [WideResidualNetwork: 1-1, Sequential: 2-1, Conv2d: 3-1, Sequential: 3-2, WideBlock: 4-1, Sequential: 3-3, WideBlock: 4-2, Sequential: 3-4, WideBlock: 4-3, Sequential: 3-5, WideBlock: 4-4, AdaptiveAvgPool2d: 1-2, Encoder: 1-3, EncoderLayer: 3-6, MultiHeadAttention: 4-5, _IntraLayerConnection: 4-6, _ConvolutionalLayer: 4-7, _IntraLayerConnection: 4-8, EncoderLayer: 3-7, MultiHeadAttention: 4-9, _IntraLayerConnection: 4-10, _ConvolutionalLayer: 4-11, _IntraLayerConnection: 4-12, LayerNorm: 2-2, Linear: 2-3, GLU: 2-4]"
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "===============================================================================================\n",
+ "Layer (type:depth-idx) Output Shape Param #\n",
+ "===============================================================================================\n",
+ "├─Encoder: 1-1 [-1, 32, 5, 59] --\n",
+ "| └─Sequential: 2-1 [-1, 32, 5, 59] --\n",
+ "| | └─Sequential: 3-1 [-1, 32, 14, 476] 544\n",
+ "| | └─Sequential: 3-2 [-1, 128, 7, 238] 65,664\n",
+ "| | └─Sequential: 3-3 [-1, 128, 6, 119] 262,272\n",
+ "| | └─Sequential: 3-4 [-1, 256, 5, 59] 524,544\n",
+ "| | └─_ResidualBlock: 3-5 [-1, 256, 5, 59] 655,360\n",
+ "| | └─_ResidualBlock: 3-6 [-1, 256, 5, 59] 655,360\n",
+ "| | └─Conv2d: 3-7 [-1, 32, 5, 59] 8,224\n",
+ "| └─VectorQuantizer: 2-2 [-1, 32, 5, 59] --\n",
+ "├─Decoder: 1-2 [-1, 1, 28, 952] --\n",
+ "| └─Sequential: 2-3 [-1, 1, 28, 952] --\n",
+ "| └─Sequential: 2-4 [-1, 256, 5, 59] --\n",
+ "| └─Sequential: 2 [] --\n",
+ "| | └─Sequential: 3-8 [-1, 256, 5, 59] (recursive)\n",
+ "| └─Sequential: 2 [] --\n",
+ "| | └─Conv2d: 3-9 [-1, 256, 5, 59] 8,448\n",
+ "| | └─_ResidualBlock: 3-10 [-1, 256, 5, 59] 655,360\n",
+ "| | └─_ResidualBlock: 3-11 [-1, 256, 5, 59] 655,360\n",
+ "| └─Sequential: 2-5 [-1, 1, 28, 952] --\n",
+ "| └─Sequential: 2 [] --\n",
+ "| | └─Sequential: 3-12 [-1, 1, 28, 952] (recursive)\n",
+ "| └─Sequential: 2 [] --\n",
+ "| | └─Sequential: 3-13 [-1, 128, 6, 118] 524,416\n",
+ "| | └─Upsample: 3-14 [-1, 128, 6, 119] --\n",
+ "| | └─Sequential: 3-15 [-1, 128, 7, 238] 262,272\n",
+ "| | └─Upsample: 3-16 [-1, 128, 7, 238] --\n",
+ "| | └─Sequential: 3-17 [-1, 32, 14, 476] 65,568\n",
+ "| | └─ConvTranspose2d: 3-18 [-1, 1, 28, 952] 513\n",
+ "| | └─Tanh: 3-19 [-1, 1, 28, 952] --\n",
+ "===============================================================================================\n",
+ "Total params: 4,343,905\n",
+ "Trainable params: 4,343,905\n",
+ "Non-trainable params: 0\n",
+ "Total mult-adds (G): 1.76\n",
+ "===============================================================================================\n",
+ "Input size (MB): 0.10\n",
+ "Forward/backward pass size (MB): 9.32\n",
+ "Params size (MB): 16.57\n",
+ "Estimated Total Size (MB): 26.00\n",
+ "===============================================================================================\n"
]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "===============================================================================================\n",
+ "Layer (type:depth-idx) Output Shape Param #\n",
+ "===============================================================================================\n",
+ "├─Encoder: 1-1 [-1, 32, 5, 59] --\n",
+ "| └─Sequential: 2-1 [-1, 32, 5, 59] --\n",
+ "| | └─Sequential: 3-1 [-1, 32, 14, 476] 544\n",
+ "| | └─Sequential: 3-2 [-1, 128, 7, 238] 65,664\n",
+ "| | └─Sequential: 3-3 [-1, 128, 6, 119] 262,272\n",
+ "| | └─Sequential: 3-4 [-1, 256, 5, 59] 524,544\n",
+ "| | └─_ResidualBlock: 3-5 [-1, 256, 5, 59] 655,360\n",
+ "| | └─_ResidualBlock: 3-6 [-1, 256, 5, 59] 655,360\n",
+ "| | └─Conv2d: 3-7 [-1, 32, 5, 59] 8,224\n",
+ "| └─VectorQuantizer: 2-2 [-1, 32, 5, 59] --\n",
+ "├─Decoder: 1-2 [-1, 1, 28, 952] --\n",
+ "| └─Sequential: 2-3 [-1, 1, 28, 952] --\n",
+ "| └─Sequential: 2-4 [-1, 256, 5, 59] --\n",
+ "| └─Sequential: 2 [] --\n",
+ "| | └─Sequential: 3-8 [-1, 256, 5, 59] (recursive)\n",
+ "| └─Sequential: 2 [] --\n",
+ "| | └─Conv2d: 3-9 [-1, 256, 5, 59] 8,448\n",
+ "| | └─_ResidualBlock: 3-10 [-1, 256, 5, 59] 655,360\n",
+ "| | └─_ResidualBlock: 3-11 [-1, 256, 5, 59] 655,360\n",
+ "| └─Sequential: 2-5 [-1, 1, 28, 952] --\n",
+ "| └─Sequential: 2 [] --\n",
+ "| | └─Sequential: 3-12 [-1, 1, 28, 952] (recursive)\n",
+ "| └─Sequential: 2 [] --\n",
+ "| | └─Sequential: 3-13 [-1, 128, 6, 118] 524,416\n",
+ "| | └─Upsample: 3-14 [-1, 128, 6, 119] --\n",
+ "| | └─Sequential: 3-15 [-1, 128, 7, 238] 262,272\n",
+ "| | └─Upsample: 3-16 [-1, 128, 7, 238] --\n",
+ "| | └─Sequential: 3-17 [-1, 32, 14, 476] 65,568\n",
+ "| | └─ConvTranspose2d: 3-18 [-1, 1, 28, 952] 513\n",
+ "| | └─Tanh: 3-19 [-1, 1, 28, 952] --\n",
+ "===============================================================================================\n",
+ "Total params: 4,343,905\n",
+ "Trainable params: 4,343,905\n",
+ "Non-trainable params: 0\n",
+ "Total mult-adds (G): 1.76\n",
+ "===============================================================================================\n",
+ "Input size (MB): 0.10\n",
+ "Forward/backward pass size (MB): 9.32\n",
+ "Params size (MB): 16.57\n",
+ "Estimated Total Size (MB): 26.00\n",
+ "==============================================================================================="
+ ]
+ },
+ "execution_count": 26,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "summary(vqvae, (1, 28, 952), device=\"cpu\", depth=3)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 94,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "up = nn.Upsample([4, 59])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 107,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([2, 32, 4, 59])"
+ ]
+ },
+ "execution_count": 107,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "up(tt).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 104,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([2, 32, 1, 59])"
+ ]
+ },
+ "execution_count": 104,
+ "metadata": {},
+ "output_type": "execute_result"
}
],
"source": [
- "summary(model, (1, 28, 952), device=\"cpu\", depth=3)"
+ "tt.shape"
]
},
{
diff --git a/src/notebooks/02c-image-patches.ipynb b/src/notebooks/02c-image-patches.ipynb
index ee9a800..fedea91 100644
--- a/src/notebooks/02c-image-patches.ipynb
+++ b/src/notebooks/02c-image-patches.ipynb
@@ -48,8 +48,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "2021-01-04 19:10:11.431 | DEBUG | text_recognizer.datasets.emnist_lines_dataset:_generate_data:159 - Generating data...\n",
- "2021-01-04 19:10:17.812 | DEBUG | text_recognizer.datasets.emnist_lines_dataset:_load_data:152 - EmnistLinesDataset loading data from HDF5...\n"
+ "2021-01-10 17:44:25.666 | DEBUG | text_recognizer.datasets.emnist_lines_dataset:_load_data:153 - EmnistLinesDataset loading data from HDF5...\n"
]
}
],
@@ -210,7 +209,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
@@ -219,17 +218,17 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"from einops.layers.torch import Rearrange\n",
- "slide = nn.Sequential(nn.Unfold(kernel_size=(28, 64), stride=(1, 54)), Rearrange(\"b (c h w) t -> b t c h w\", h=28, w=64, c=1))"
+ "slide = nn.Sequential(nn.Unfold(kernel_size=(28, 46), stride=(1, 46)), Rearrange(\"b (c h w) t -> b t c h w\", h=28, w=46, c=1))"
]
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
@@ -238,7 +237,7 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
@@ -247,17 +246,27 @@
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": 33,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 1, 28, 952])"
+ ]
+ },
+ "execution_count": 33,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
- "p=28\n",
- "x = rearrange(data, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)"
+ "data.shape"
]
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
@@ -266,7 +275,7 @@
},
{
"cell_type": "code",
- "execution_count": 25,
+ "execution_count": 35,
"metadata": {},
"outputs": [
{
@@ -275,7 +284,7 @@
"torch.Size([1, 34, 784])"
]
},
- "execution_count": 25,
+ "execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
@@ -286,7 +295,7 @@
},
{
"cell_type": "code",
- "execution_count": 31,
+ "execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
@@ -296,7 +305,7 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
@@ -305,16 +314,16 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "torch.Size([17, 1, 28, 64])"
+ "torch.Size([20, 1, 28, 46])"
]
},
- "execution_count": 15,
+ "execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
@@ -325,14 +334,14 @@
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": 38,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
- "image/png": "\n",
+ "image/png": "\n",
"text/plain": [
"<Figure size 1440x1440 with 5 Axes>"
]
@@ -361,7 +370,19 @@
"cell_type": "code",
"execution_count": 18,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "ename": "ImportError",
+ "evalue": "cannot import name 'fetch_data_loaders' from 'text_recognizer.datasets.util' (/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/text_recognizer/datasets/util.py)",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m<ipython-input-18-5d40384147e9>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mtext_recognizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdatasets\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutil\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mfetch_data_loaders\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+ "\u001b[0;31mImportError\u001b[0m: cannot import name 'fetch_data_loaders' from 'text_recognizer.datasets.util' (/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/text_recognizer/datasets/util.py)"
+ ]
+ }
+ ],
"source": [
"from text_recognizer.datasets.util import fetch_data_loaders"
]
diff --git a/src/notebooks/04a-look-at-iam-lines.ipynb b/src/notebooks/04a-look-at-iam-lines.ipynb
index 036604d..de59a85 100644
--- a/src/notebooks/04a-look-at-iam-lines.ipynb
+++ b/src/notebooks/04a-look-at-iam-lines.ipynb
@@ -33,19 +33,48 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
- "transform = [{\"type\": \"ToTensor\", \"args\": None}, \n",
- " {\"type\": \"ApplyContrast\", \"args\": {\"low\": 0.0, \"high\": 0.15}},\n",
+ "transform = [{\"type\": \"ToPILImage\", \"args\": None}, \n",
+ " #{\"type\": \"RandomResizeCrop\", \"args\": None}, \n",
+ " {\"type\": \"RandomRotation\", \"args\": {\"degrees\": 0.8, \"fill\": 0}}, \n",
+ " {\"type\": \"ColorJitter\", \"args\": {\"brightness\": 0.5, \"contrast\": 0.5, \"saturation\": 0.5, \"hue\": 0.5}}, \n",
+ " {\"type\": \"ToTensor\", \"args\": None}, \n",
+ " {\"type\": \"Normalize\", \"args\": {\"mean\": [0.912], \"std\": 0.168}},\n",
" #{\"type\": \"RandomAffine\", \"args\": {\"degrees\": [-0.25, 0.25], \"scale\": [0.98, 1.0]}}\n",
" ]"
]
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 61,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[{'type': 'ToPILImage', 'args': None},\n",
+ " {'type': 'RandomRotation', 'args': {'degrees': 0.8, 'fill': 0}},\n",
+ " {'type': 'ColorJitter',\n",
+ " 'args': {'brightness': 0.5, 'contrast': 0.5, 'saturation': 0.5, 'hue': 0.5}},\n",
+ " {'type': 'ToTensor', 'args': None},\n",
+ " {'type': 'Normalize', 'args': {'mean': [0.912], 'std': 0.168}}]"
+ ]
+ },
+ "execution_count": 61,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "transform"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 62,
"metadata": {},
"outputs": [
{
@@ -69,7 +98,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 63,
"metadata": {
"scrolled": true
},
@@ -80,7 +109,7 @@
"(28, 952)"
]
},
- "execution_count": 5,
+ "execution_count": 63,
"metadata": {},
"output_type": "execute_result"
}
@@ -91,7 +120,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 64,
"metadata": {},
"outputs": [
{
@@ -100,7 +129,7 @@
"(97, 54)"
]
},
- "execution_count": 6,
+ "execution_count": 64,
"metadata": {},
"output_type": "execute_result"
}
@@ -111,7 +140,16 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 65,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from torchvision.transforms import ToPILImage"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
@@ -123,7 +161,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 69,
"metadata": {},
"outputs": [
{
@@ -144,7 +182,7 @@
},
{
"data": {
- "image/png": "\n",
+ "image/png": "\n",
"text/plain": [
"<Figure size 1440x1440 with 1 Axes>"
]
@@ -154,7 +192,7 @@
},
{
"data": {
- "image/png": "\n",
+ "image/png": "\n",
"text/plain": [
"<Figure size 1440x1440 with 1 Axes>"
]
@@ -164,7 +202,7 @@
},
{
"data": {
- "image/png": "\n",
+ "image/png": "\n",
"text/plain": [
"<Figure size 1440x1440 with 1 Axes>"
]
@@ -174,7 +212,7 @@
},
{
"data": {
- "image/png": "\n",
+ "image/png": "\n",
"text/plain": [
"<Figure size 1440x1440 with 1 Axes>"
]
@@ -184,7 +222,7 @@
},
{
"data": {
- "image/png": "\n",
+ "image/png": "\n",
"text/plain": [
"<Figure size 1440x1440 with 1 Axes>"
]
@@ -194,7 +232,7 @@
},
{
"data": {
- "image/png": "\n",
+ "image/png": "\n",
"text/plain": [
"<Figure size 1440x1440 with 1 Axes>"
]
@@ -204,7 +242,7 @@
},
{
"data": {
- "image/png": "\n",
+ "image/png": "\n",
"text/plain": [
"<Figure size 1440x1440 with 1 Axes>"
]
@@ -214,7 +252,7 @@
},
{
"data": {
- "image/png": "\n",
+ "image/png": "\n",
"text/plain": [
"<Figure size 1440x1440 with 1 Axes>"
]
@@ -224,7 +262,7 @@
},
{
"data": {
- "image/png": "\n",
+ "image/png": "\n",
"text/plain": [
"<Figure size 1440x1440 with 1 Axes>"
]
@@ -234,7 +272,7 @@
},
{
"data": {
- "image/png": "\n",
+ "image/png": "\n",
"text/plain": [
"<Figure size 1440x1440 with 1 Axes>"
]
@@ -258,7 +296,7 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 68,
"metadata": {},
"outputs": [],
"source": [
@@ -267,218 +305,34 @@
},
{
"cell_type": "code",
- "execution_count": 41,
- "metadata": {},
- "outputs": [],
- "source": [
- "data1 = torch.stack((data, data))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 42,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([2, 1, 28, 952])"
- ]
- },
- "execution_count": 42,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "data1.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 32,
- "metadata": {},
- "outputs": [],
- "source": [
- "patches = sliding_window(data.unsqueeze(0), (28, 32), (1, 28))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 52,
- "metadata": {},
- "outputs": [],
- "source": [
- "patches = sliding_window(data1, (28, 32), (1, 28))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 53,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([2, 33, 1, 28, 32])"
- ]
- },
- "execution_count": 53,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "patches.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 48,
- "metadata": {},
- "outputs": [],
- "source": [
- "patches = patches[1]"
- ]
- },
- {
- "cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [],
- "source": [
- "from einops import rearrange"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 54,
- "metadata": {},
- "outputs": [],
- "source": [
- "p = rearrange(patches, \"b t c h w -> (b t) c h w\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 56,
- "metadata": {},
- "outputs": [],
- "source": [
- "patches = rearrange(p, \"(b t) c h w -> b c h (t w)\", b=2)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 57,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([2, 1, 28, 1056])"
- ]
- },
- "execution_count": 57,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "patches.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 58,
- "metadata": {
- "scrolled": true
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([66, 1, 28, 32])"
- ]
- },
- "execution_count": 58,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "p.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([1, 28, 952])"
- ]
- },
- "execution_count": 11,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
"source": [
"data.shape"
]
},
{
"cell_type": "code",
- "execution_count": 37,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
- "p=7\n",
- "x = rearrange(data.unsqueeze(0), 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = 28, p2 = p)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 42,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([1, 136, 196])"
- ]
- },
- "execution_count": 42,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "x.shape"
+ "patches = sliding_window(data.unsqueeze(0), (28, 46), (1, 46))"
]
},
{
"cell_type": "code",
- "execution_count": 39,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
- "patches = rearrange(x, 'b t (h w) -> b t h w', h = 28, w = p)"
+ "patches.shape"
]
},
{
"cell_type": "code",
- "execution_count": 28,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -487,26 +341,13 @@
},
{
"cell_type": "code",
- "execution_count": 49,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "image/png": "\n",
- "text/plain": [
- "<Figure size 1440x1440 with 12 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
+ "outputs": [],
"source": [
"fig = plt.figure(figsize=(20, 20))\n",
- "for i in range(12):\n",
- " ax = fig.add_subplot(1, 12, i + 1)\n",
+ "for i in range(6):\n",
+ " ax = fig.add_subplot(1, 6, i + 1)\n",
" ax.imshow(patches[i].squeeze(0), cmap='gray')"
]
},
diff --git a/src/notebooks/Untitled.ipynb b/src/notebooks/06-try-transformer-model-predictions.ipynb
index d39e111..d39e111 100644
--- a/src/notebooks/Untitled.ipynb
+++ b/src/notebooks/06-try-transformer-model-predictions.ipynb
diff --git a/src/notebooks/07-look-at-lexicon.ipynb b/src/notebooks/07-look-at-lexicon.ipynb
new file mode 100644
index 0000000..b7a5a0e
--- /dev/null
+++ b/src/notebooks/07-look-at-lexicon.ipynb
@@ -0,0 +1,1119 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The autoreload extension is already loaded. To reload it, use:\n",
+ " %reload_ext autoreload\n"
+ ]
+ }
+ ],
+ "source": [
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "\n",
+ "%matplotlib inline\n",
+ "import matplotlib.pyplot as plt\n",
+ "from pathlib import Path\n",
+ "import numpy as np\n",
+ "from PIL import Image\n",
+ "import torch.nn.functional as F\n",
+ "import torch\n",
+ "from torch import nn\n",
+ "from torchsummary import summary\n",
+ "from importlib.util import find_spec\n",
+ "if find_spec(\"text_recognizer\") is None:\n",
+ " import sys\n",
+ " sys.path.append('..')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "path = Path(\"../\").resolve().parent / \"data\" / \"processed\" / \"iam_lines\" / \"iamdb_1kwp_lex_1000.txt\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "PosixPath('/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/data/processed/iam_lines/iamdb_1kwp_lex_1000.txt')"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "path"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open(path, \"r\") as f:\n",
+ " lex = (line.strip().split() for line in f)\n",
+ " lex = {line[0]: line[1:] for line in lex}\n",
+ " #print(len(lex))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'!': ['▁', '!'],\n",
+ " '\"': ['▁', '\"'],\n",
+ " '&': ['▁', '&'],\n",
+ " \"'\": ['▁', \"'\"],\n",
+ " \"'30s\": ['▁', \"'\", '3', '0', 's'],\n",
+ " \"'61\": ['▁', \"'\", '6', '1'],\n",
+ " \"'d\": ['▁', \"'\", 'd'],\n",
+ " \"'ll\": ['▁', \"'\", 'll'],\n",
+ " \"'m\": ['▁', \"'\", 'm'],\n",
+ " \"'re\": ['▁', \"'\", 're'],\n",
+ " \"'s\": ['▁', \"'\", 's'],\n",
+ " \"'ve\": ['▁', \"'\", 've'],\n",
+ " '(': ['▁', '('],\n",
+ " ')': ['▁', ')'],\n",
+ " '*': ['▁', '*'],\n",
+ " '+2.8': ['▁', '+', '2', '.', '8'],\n",
+ " '+3.6': ['▁', '+', '3', '.', '6'],\n",
+ " ',': ['▁', ','],\n",
+ " '-': ['▁', '-'],\n",
+ " '-2.6': ['▁', '-', '2', '.', '6'],\n",
+ " '-5.4': ['▁', '-', '5', '.', '4'],\n",
+ " '.': ['▁', '.'],\n",
+ " '...': ['▁', '.', '.', '.'],\n",
+ " '0m': ['▁', '0', 'm'],\n",
+ " '1': ['▁', '1'],\n",
+ " '1,157': ['▁', '1', ',', '1', '5', '7'],\n",
+ " '1,400': ['▁', '1', ',', '4', '0', '0'],\n",
+ " '1,500': ['▁', '1', ',', '5', '0', '0'],\n",
+ " '1-2': ['▁', '1', '-', '2'],\n",
+ " '1.8': ['▁', '1', '.', '8'],\n",
+ " '1/2': ['▁', '1', '/', '2'],\n",
+ " '1/2-in.-long': ['▁', '1', '/', '2', '-', 'in', '.', '-', 'long'],\n",
+ " '1/4': ['▁', '1', '/', '4'],\n",
+ " '10': ['▁', '10'],\n",
+ " '10,000': ['▁', '10', ',', '0', '0', '0'],\n",
+ " '100': ['▁', '10', '0'],\n",
+ " '100,000,000': ['▁', '10', '0', ',', '0', '00,000'],\n",
+ " '104': ['▁', '10', '4'],\n",
+ " '11': ['▁', '1', '1'],\n",
+ " '12': ['▁', '1', '2'],\n",
+ " '12,000-word': ['▁', '1', '2', ',', '0', '0', '0', '-', 'word'],\n",
+ " '125': ['▁', '1', '2', '5'],\n",
+ " '13': ['▁', '1', '3'],\n",
+ " '13,000': ['▁', '1', '3', ',', '0', '0', '0'],\n",
+ " '14': ['▁', '1', '4'],\n",
+ " '15': ['▁', '1', '5'],\n",
+ " '15,000,000': ['▁', '1', '5', ',', '0', '00,000'],\n",
+ " '15-17': ['▁', '1', '5', '-', '1', '7'],\n",
+ " '15-nation': ['▁', '1', '5', '-', 'n', 'ation'],\n",
+ " '15-year-olds': ['▁', '1', '5', '-', 'year', '-', 'old', 's'],\n",
+ " '150,000,000': ['▁', '1', '5', '0', ',', '0', '00,000'],\n",
+ " '16': ['▁', '1', '6'],\n",
+ " '16,000': ['▁', '1', '6', ',', '0', '0', '0'],\n",
+ " '160': ['▁', '1', '6', '0'],\n",
+ " '163,000,000': ['▁', '1', '6', '3', ',', '0', '00,000'],\n",
+ " '167': ['▁', '1', '6', '7'],\n",
+ " '17': ['▁', '1', '7'],\n",
+ " '18': ['▁', '1', '8'],\n",
+ " '18.1': ['▁', '1', '8', '.', '1'],\n",
+ " '1830': ['▁', '1', '8', '3', '0'],\n",
+ " \"1830's\": ['▁', '1', '8', '3', '0', \"'\", 's'],\n",
+ " '1834': ['▁', '1', '8', '3', '4'],\n",
+ " '1897': ['▁', '1', '8', '9', '7'],\n",
+ " '19': ['▁', '1', '9'],\n",
+ " '19.5': ['▁', '1', '9', '.', '5'],\n",
+ " '1910': ['▁', '1', '9', '10'],\n",
+ " '1913': ['▁', '1', '9', '1', '3'],\n",
+ " '1914': ['▁', '1', '9', '1', '4'],\n",
+ " '1914-18': ['▁', '1', '9', '1', '4', '-', '1', '8'],\n",
+ " '1918': ['▁', '1', '9', '1', '8'],\n",
+ " '1920': ['▁', '1', '9', '2', '0'],\n",
+ " '1930': ['▁', '1', '9', '3', '0'],\n",
+ " '1931': ['▁', '1', '9', '3', '1'],\n",
+ " '1932': ['▁', '1', '9', '3', '2'],\n",
+ " '1934': ['▁', '1', '9', '3', '4'],\n",
+ " '1936': ['▁', '1', '9', '3', '6'],\n",
+ " '1939': ['▁', '1', '9', '3', '9'],\n",
+ " '1943': ['▁', '1', '9', '4', '3'],\n",
+ " '1944': ['▁', '1', '9', '4', '4'],\n",
+ " '1950': ['▁', '1', '9', '5', '0'],\n",
+ " '1951': ['▁', '1', '9', '5', '1'],\n",
+ " '1952': ['▁', '1', '9', '5', '2'],\n",
+ " '1953': ['▁', '1', '9', '5', '3'],\n",
+ " '1954': ['▁', '1', '9', '5', '4'],\n",
+ " '1956': ['▁', '1', '9', '5', '6'],\n",
+ " '1957': ['▁', '1', '9', '5', '7'],\n",
+ " '1958': ['▁', '1', '9', '5', '8'],\n",
+ " '1959': ['▁', '1', '9', '5', '9'],\n",
+ " '1960': ['▁', '1960'],\n",
+ " '1960s': ['▁', '1960', 's'],\n",
+ " '1961': ['▁', '1', '9', '6', '1'],\n",
+ " '1963': ['▁', '1', '9', '6', '3'],\n",
+ " '19th': ['▁', '1', '9', 'th'],\n",
+ " '1superceded': ['▁', '1', 'superceded'],\n",
+ " \"1tho'\": ['▁', '1', 'tho', \"'\"],\n",
+ " '2': ['▁', '2'],\n",
+ " '2,000': ['▁', '2', ',', '0', '0', '0'],\n",
+ " '2,415,000,000': ['▁', '2', ',', '4', '1', '5', ',', '0', '00,000'],\n",
+ " '20': ['▁', '2', '0'],\n",
+ " '20-month-old': ['▁', '2', '0', '-', 'month', '-', 'old'],\n",
+ " '200': ['▁', '2', '0', '0'],\n",
+ " '20th-century': ['▁', '2', '0', 'th', '-', 'cent', 'ur', 'y'],\n",
+ " '21': ['▁', '2', '1'],\n",
+ " '210million': ['▁', '2', '10', 'million'],\n",
+ " '22': ['▁', '2', '2'],\n",
+ " '23.1': ['▁', '2', '3', '.', '1'],\n",
+ " '24': ['▁', '2', '4'],\n",
+ " '24-strong': ['▁', '2', '4', '-', 'strong'],\n",
+ " '25': ['▁', '2', '5'],\n",
+ " '27': ['▁', '2', '7'],\n",
+ " '28.5': ['▁', '2', '8', '.', '5'],\n",
+ " '280,000': ['▁', '2', '8', '0', ',', '0', '0', '0'],\n",
+ " '287': ['▁', '2', '8', '7'],\n",
+ " '288': ['▁', '2', '8', '8'],\n",
+ " '2bhoys': ['▁', '2', 'b', 'ho', 'y', 's'],\n",
+ " '2ole': ['▁', '2', 'o', 'le'],\n",
+ " '2pianna': ['▁', '2', 'p', 'i', 'an', 'n', 'a'],\n",
+ " '2skint': ['▁', '2', 's', 'k', 'in', 't'],\n",
+ " '3': ['▁', '3'],\n",
+ " '3,000': ['▁', '3', ',', '0', '0', '0'],\n",
+ " '3.6': ['▁', '3', '.', '6'],\n",
+ " '3/0': ['▁', '3', '/', '0'],\n",
+ " '3/4': ['▁', '3', '/', '4'],\n",
+ " '30': ['▁', '3', '0'],\n",
+ " '30-day': ['▁', '3', '0', '-', 'day'],\n",
+ " '30-minute': ['▁', '3', '0', '-', 'minute'],\n",
+ " '300,000': ['▁', '3', '00,000'],\n",
+ " '32': ['▁', '3', '2'],\n",
+ " '33': ['▁', '3', '3'],\n",
+ " '34': ['▁', '3', '4'],\n",
+ " '35': ['▁', '3', '5'],\n",
+ " '357million': ['▁', '3', '5', '7', 'million'],\n",
+ " '36': ['▁', '3', '6'],\n",
+ " '37,000,000': ['▁', '3', '7', ',', '0', '00,000'],\n",
+ " '37.2': ['▁', '3', '7', '.', '2'],\n",
+ " '38': ['▁', '3', '8'],\n",
+ " '4': ['▁', '4'],\n",
+ " '4.8': ['▁', '4', '.', '8'],\n",
+ " '40': ['▁', '4', '0'],\n",
+ " '400': ['▁', '4', '0', '0'],\n",
+ " '400,000': ['▁', '4', '00,000'],\n",
+ " '420000': ['▁', '4', '2', '0', '0', '0', '0'],\n",
+ " '43': ['▁', '4', '3'],\n",
+ " '450': ['▁', '4', '5', '0'],\n",
+ " '5': ['▁', '5'],\n",
+ " '5,000': ['▁', '5', ',', '0', '0', '0'],\n",
+ " '5.30': ['▁', '5', '.', '3', '0'],\n",
+ " '5/8': ['▁', '5', '/', '8'],\n",
+ " '50': ['▁', '5', '0'],\n",
+ " '50,000': ['▁', '5', '0', ',', '0', '0', '0'],\n",
+ " '500': ['▁', '5', '0', '0'],\n",
+ " '53-year-old': ['▁', '5', '3', '-', 'year', '-', 'old'],\n",
+ " '55': ['▁', '5', '5'],\n",
+ " '550,000': ['▁', '5', '5', '0', ',', '0', '0', '0'],\n",
+ " '58': ['▁', '5', '8'],\n",
+ " '6': ['▁', '6'],\n",
+ " '6,000': ['▁', '6', ',', '0', '0', '0'],\n",
+ " '60': ['▁', '6', '0'],\n",
+ " '600': ['▁', '6', '0', '0'],\n",
+ " '600,000': ['▁', '6', '00,000'],\n",
+ " '61-year-old': ['▁', '6', '1', '-', 'year', '-', 'old'],\n",
+ " '68': ['▁', '6', '8'],\n",
+ " '6al': ['▁', '6', 'al'],\n",
+ " '6tic': ['▁', '6', 'tic'],\n",
+ " '7.30': ['▁', '7', '.', '3', '0'],\n",
+ " '7.42': ['▁', '7', '.', '4', '2'],\n",
+ " '70': ['▁', '7', '0'],\n",
+ " '70,000,000': ['▁', '7', '0', ',', '0', '00,000'],\n",
+ " '707': ['▁', '7', '0', '7'],\n",
+ " '73': ['▁', '7', '3'],\n",
+ " '750': ['▁', '7', '5', '0'],\n",
+ " '8': ['▁', '8'],\n",
+ " '8,000,000': ['▁', '8', ',', '0', '00,000'],\n",
+ " '8.25': ['▁', '8', '.', '2', '5'],\n",
+ " '8.4': ['▁', '8', '.', '4'],\n",
+ " '80': ['▁', '8', '0'],\n",
+ " '800': ['▁', '8', '0', '0'],\n",
+ " '800,000': ['▁', '8', '00,000'],\n",
+ " '86': ['▁', '8', '6'],\n",
+ " '88': ['▁', '8', '8'],\n",
+ " '88-year-old': ['▁', '8', '8', '-', 'year', '-', 'old'],\n",
+ " '89': ['▁', '8', '9'],\n",
+ " '89-year-old': ['▁', '8', '9', '-', 'year', '-', 'old'],\n",
+ " '9.30': ['▁', '9', '.', '3', '0'],\n",
+ " '9.40': ['▁', '9', '.', '4', '0'],\n",
+ " '90-day': ['▁', '9', '0', '-', 'day'],\n",
+ " '90-minute': ['▁', '9', '0', '-', 'minute'],\n",
+ " '91': ['▁', '9', '1'],\n",
+ " '950': ['▁', '9', '5', '0'],\n",
+ " '97.5': ['▁', '9', '7', '.', '5'],\n",
+ " ':': ['▁', ':'],\n",
+ " ';': ['▁', ';'],\n",
+ " '?': ['▁', '?'],\n",
+ " 'a': ['▁', 'a'],\n",
+ " 'abandon': ['▁', 'a', 'b', 'and', 'on'],\n",
+ " 'abandoned': ['▁', 'a', 'b', 'and', 'on', 'ed'],\n",
+ " 'abandoning': ['▁', 'a', 'b', 'and', 'on', 'ing'],\n",
+ " 'abashed': ['▁', 'a', 'bas', 'he', 'd'],\n",
+ " 'ability': ['▁', 'a', 'b', 'il', 'ity'],\n",
+ " 'able': ['▁', 'able'],\n",
+ " 'able-bodied': ['▁', 'able', '-', 'bo', 'die', 'd'],\n",
+ " 'abolish': ['▁', 'a', 'bo', 'l', 'ish'],\n",
+ " 'abolished': ['▁', 'a', 'bo', 'l', 'ish', 'ed'],\n",
+ " 'abolition': ['▁', 'a', 'bo', 'li', 'tion'],\n",
+ " 'abortion': ['▁', 'a', 'b', 'or', 'tion'],\n",
+ " 'abou': ['▁', 'a', 'bo', 'u'],\n",
+ " 'about': ['▁', 'about'],\n",
+ " 'about-': ['▁', 'about', '-'],\n",
+ " 'above': ['▁', 'a', 'bo', 've'],\n",
+ " 'abreast': ['▁', 'a', 'br', 'east'],\n",
+ " 'abroad': ['▁', 'a', 'b', 'ro', 'ad'],\n",
+ " 'absence': ['▁', 'a', 'b', 's', 'ence'],\n",
+ " 'absent': ['▁', 'a', 'b', 's', 'ent'],\n",
+ " 'absolutely': ['▁', 'a', 'b', 'solut', 'e', 'ly'],\n",
+ " 'abstraction': ['▁', 'a', 'b', 's', 'tr', 'action'],\n",
+ " 'abundance': ['▁', 'a', 'b', 'un', 'd', 'ance'],\n",
+ " 'ac-': ['▁', 'ac', '-'],\n",
+ " 'academic': ['▁', 'ac', 'a', 'de', 'm', 'ic'],\n",
+ " 'accent': ['▁', 'ac', 'cent'],\n",
+ " 'accents': ['▁', 'ac', 'cent', 's'],\n",
+ " 'accept': ['▁', 'accept'],\n",
+ " 'acceptable': ['▁', 'accept', 'able'],\n",
+ " 'accepted': ['▁', 'accept', 'ed'],\n",
+ " 'accepting': ['▁', 'accept', 'ing'],\n",
+ " 'accessories': ['▁', 'ac', 'ce', 's', 'so', 'ries'],\n",
+ " 'accident': ['▁', 'ac', 'c', 'id', 'ent'],\n",
+ " 'accidental': ['▁', 'ac', 'c', 'id', 'ent', 'al'],\n",
+ " 'accommodate': ['▁', 'ac', 'com', 'mo', 'date'],\n",
+ " 'accommodation': ['▁', 'ac', 'com', 'mo', 'd', 'ation'],\n",
+ " 'accompanied': ['▁', 'ac', 'com', 'pan', 'i', 'ed'],\n",
+ " 'accompanist': ['▁', 'ac', 'com', 'pan', 'is', 't'],\n",
+ " 'accompany': ['▁', 'ac', 'com', 'p', 'any'],\n",
+ " 'accomplished': ['▁', 'ac', 'com', 'p', 'l', 'ish', 'ed'],\n",
+ " 'accomplishments': ['▁', 'ac', 'com', 'p', 'l', 'ish', 'ment', 's'],\n",
+ " 'according': ['▁', 'ac', 'c', 'or', 'd', 'ing'],\n",
+ " 'account': ['▁', 'ac', 'count'],\n",
+ " 'accountancy': ['▁', 'ac', 'count', 'an', 'c', 'y'],\n",
+ " 'accra': ['▁', 'ac', 'c', 'ra'],\n",
+ " \"accra's\": ['▁', 'ac', 'c', 'ra', \"'\", 's'],\n",
+ " 'accuracy': ['▁', 'ac', 'cur', 'ac', 'y'],\n",
+ " 'accurate': ['▁', 'ac', 'cur', 'ate'],\n",
+ " 'accurately': ['▁', 'ac', 'cur', 'ate', 'ly'],\n",
+ " 'accused': ['▁', 'ac', 'c', 'used'],\n",
+ " 'achieved': ['▁', 'a', 'ch', 'i', 'e', 'v', 'ed'],\n",
+ " 'achievement': ['▁', 'a', 'ch', 'i', 'e', 've', 'ment'],\n",
+ " 'acquaintance': ['▁', 'ac', 'q', 'u', 'a', 'in', 't', 'ance'],\n",
+ " 'acquaintances': ['▁', 'ac', 'q', 'u', 'a', 'in', 't', 'ance', 's'],\n",
+ " 'acres': ['▁', 'ac', 're', 's'],\n",
+ " 'across': ['▁', 'a', 'cross'],\n",
+ " 'act': ['▁', 'act'],\n",
+ " 'acting': ['▁', 'act', 'ing'],\n",
+ " 'action': ['▁', 'action'],\n",
+ " 'actions': ['▁', 'action', 's'],\n",
+ " 'active': ['▁', 'act', 'ive'],\n",
+ " 'activists': ['▁', 'act', 'i', 'vi', 'st', 's'],\n",
+ " 'activities': ['▁', 'act', 'i', 'v', 'it', 'ies'],\n",
+ " 'activity': ['▁', 'act', 'i', 'v', 'ity'],\n",
+ " 'acton': ['▁', 'act', 'on'],\n",
+ " 'actor': ['▁', 'act', 'or'],\n",
+ " 'actress': ['▁', 'act', 're', 's', 's'],\n",
+ " 'acts': ['▁', 'act', 's'],\n",
+ " 'actual': ['▁', 'act', 'ual'],\n",
+ " 'actually': ['▁', 'act', 'ual', 'ly'],\n",
+ " 'adamafio': ['▁', 'ad', 'a', 'ma', 'f', 'i', 'o'],\n",
+ " 'adaptation': ['▁', 'ad', 'ap', 't', 'ation'],\n",
+ " 'adapted': ['▁', 'ad', 'ap', 'ted'],\n",
+ " 'adapting': ['▁', 'ad', 'ap', 't', 'ing'],\n",
+ " 'add': ['▁', 'ad', 'd'],\n",
+ " 'added': ['▁', 'ad', 'd', 'ed'],\n",
+ " 'adding': ['▁', 'adding'],\n",
+ " 'addition': ['▁', 'ad', 'd', 'it', 'ion'],\n",
+ " 'additions': ['▁', 'ad', 'd', 'it', 'ion', 's'],\n",
+ " 'address': ['▁', 'ad', 'dr', 'es', 's'],\n",
+ " 'addressed': ['▁', 'ad', 'dr', 'es', 's', 'ed'],\n",
+ " 'addresses': ['▁', 'ad', 'dr', 'es', 'se', 's'],\n",
+ " 'addressing': ['▁', 'ad', 'dr', 'es', 's', 'ing'],\n",
+ " 'adenauer': ['▁', 'adenauer'],\n",
+ " \"adenauer's\": ['▁', 'adenauer', \"'\", 's'],\n",
+ " 'adequate': ['▁', 'ad', 'equa', 'te'],\n",
+ " 'adhem': ['▁', 'ad', 'he', 'm'],\n",
+ " 'adjust': ['▁', 'ad', 'just'],\n",
+ " 'adjustment': ['▁', 'ad', 'just', 'ment'],\n",
+ " 'administration': ['▁', 'ad', 'ministr', 'ation'],\n",
+ " \"administration's\": ['▁', 'ad', 'ministr', 'ation', \"'\", 's'],\n",
+ " 'administrative': ['▁', 'ad', 'ministr', 'at', 'ive'],\n",
+ " 'admiralty': ['▁', 'ad', 'm', 'i', 'r', 'al', 'ty'],\n",
+ " 'admire': ['▁', 'ad', 'm', 'i', 're'],\n",
+ " 'admit': ['▁', 'ad', 'm', 'it'],\n",
+ " 'admitted': ['▁', 'ad', 'm', 'it', 'ted'],\n",
+ " 'admitting': ['▁', 'ad', 'm', 'it', 't', 'ing'],\n",
+ " 'adopted': ['▁', 'a', 'do', 'p', 'ted'],\n",
+ " 'adopting': ['▁', 'a', 'do', 'p', 't', 'ing'],\n",
+ " 'adoption': ['▁', 'a', 'do', 'p', 'tion'],\n",
+ " 'adult': ['▁', 'ad', 'ul', 't'],\n",
+ " 'advance': ['▁', 'ad', 'v', 'ance'],\n",
+ " 'advanced': ['▁', 'ad', 'v', 'ance', 'd'],\n",
+ " 'advancing': ['▁', 'ad', 'v', 'an', 'c', 'ing'],\n",
+ " 'advantage': ['▁', 'advantage'],\n",
+ " 'advantages': ['▁', 'advantage', 's'],\n",
+ " 'advertisement': ['▁', 'ad', 'ver', 't', 'is', 'e', 'ment'],\n",
+ " 'advertisements': ['▁', 'ad', 'ver', 't', 'is', 'ements'],\n",
+ " 'advice': ['▁', 'advi', 'ce'],\n",
+ " 'advisability': ['▁', 'advi', 's', 'a', 'b', 'il', 'ity'],\n",
+ " 'advise': ['▁', 'advise'],\n",
+ " 'advised': ['▁', 'advise', 'd'],\n",
+ " 'advisers': ['▁', 'advise', 'r', 's'],\n",
+ " 'advocate': ['▁', 'ad', 'v', 'o', 'c', 'ate'],\n",
+ " 'af-': ['▁', 'a', 'f', '-'],\n",
+ " 'affairs': ['▁', 'a', 'f', 'f', 'air', 's'],\n",
+ " 'affected': ['▁', 'a', 'f', 'fe', 'c', 'ted'],\n",
+ " 'affection': ['▁', 'a', 'f', 'fe', 'c', 'tion'],\n",
+ " 'affilia-': ['▁', 'a', 'f', 'f', 'il', 'i', 'a', '-'],\n",
+ " 'affiliations': ['▁', 'a', 'f', 'f', 'il', 'i', 'ation', 's'],\n",
+ " 'affluence': ['▁', 'a', 'f', 'f', 'l', 'u', 'ence'],\n",
+ " 'affluent': ['▁', 'a', 'f', 'f', 'l', 'u', 'ent'],\n",
+ " 'afford': ['▁', 'a', 'f', 'for', 'd'],\n",
+ " 'afraid': ['▁', 'a', 'fr', 'a', 'id'],\n",
+ " 'africa': ['▁', 'africa'],\n",
+ " \"africa's\": ['▁', 'africa', \"'\", 's'],\n",
+ " 'african': ['▁', 'african'],\n",
+ " 'africans': ['▁', 'african', 's'],\n",
+ " 'after': ['▁', 'after'],\n",
+ " 'afternoon': ['▁', 'after', 'no', 'on'],\n",
+ " 'afterwards': ['▁', 'after', 'ward', 's'],\n",
+ " 'again': ['▁', 'again'],\n",
+ " 'against': ['▁', 'against'],\n",
+ " 'age': ['▁', 'age'],\n",
+ " 'age-structure': ['▁', 'age', '-', 's', 'tru', 'c', 'ture'],\n",
+ " 'aged': ['▁', 'aged'],\n",
+ " 'ageing': ['▁', 'age', 'ing'],\n",
+ " 'agent': ['▁', 'a', 'g', 'ent'],\n",
+ " 'agents': ['▁', 'a', 'g', 'ent', 's'],\n",
+ " 'ages': ['▁', 'age', 's'],\n",
+ " 'agitation': ['▁', 'a', 'g', 'it', 'ation'],\n",
+ " 'ago': ['▁', 'a', 'go'],\n",
+ " 'agree': ['▁', 'agree'],\n",
+ " 'agreed': ['▁', 'agree', 'd'],\n",
+ " 'agreement': ['▁', 'agree', 'ment'],\n",
+ " 'agreements': ['▁', 'agree', 'ment', 's'],\n",
+ " 'agriculture': ['▁', 'a', 'gr', 'ic', 'ul', 'ture'],\n",
+ " 'ahead': ['▁', 'a', 'head'],\n",
+ " 'aid': ['▁', 'a', 'id'],\n",
+ " 'aide': ['▁', 'a', 'i', 'de'],\n",
+ " 'aided': ['▁', 'a', 'id', 'ed'],\n",
+ " 'aides': ['▁', 'a', 'id', 'es'],\n",
+ " 'aim': ['▁', 'a', 'im'],\n",
+ " 'aimed': ['▁', 'a', 'im', 'ed'],\n",
+ " 'aiming': ['▁', 'a', 'im', 'ing'],\n",
+ " 'air': ['▁', 'air'],\n",
+ " 'aircraft': ['▁', 'air', 'craft'],\n",
+ " 'aired': ['▁', 'air', 'ed'],\n",
+ " \"airliner's\": ['▁', 'air', 'line', 'r', \"'\", 's'],\n",
+ " 'airmen': ['▁', 'air', 'men'],\n",
+ " 'airport': ['▁', 'air', 'port'],\n",
+ " 'akin': ['▁', 'a', 'k', 'in'],\n",
+ " \"aladdin's\": ['▁', 'al', 'ad', 'd', 'in', \"'\", 's'],\n",
+ " 'alan': ['▁', 'al', 'an'],\n",
+ " 'alarm': ['▁', 'al', 'arm'],\n",
+ " 'alarmed': ['▁', 'al', 'arm', 'ed'],\n",
+ " 'alas': ['▁', 'al', 'as'],\n",
+ " 'alcoholic': ['▁', 'al', 'co', 'ho', 'li', 'c'],\n",
+ " 'algeria': ['▁', 'al', 'g', 'er', 'i', 'a'],\n",
+ " 'alike': ['▁', 'a', 'like'],\n",
+ " 'alive': ['▁', 'a', 'live'],\n",
+ " 'all': ['▁', 'all'],\n",
+ " 'all-regular': ['▁', 'all', '-', 'regular'],\n",
+ " 'alleged': ['▁', 'al', 'leg', 'ed'],\n",
+ " 'allen': ['▁', 'all', 'en'],\n",
+ " 'alleviation': ['▁', 'alleviation'],\n",
+ " 'alley': ['▁', 'al', 'le', 'y'],\n",
+ " 'alliance': ['▁', 'all', 'i', 'ance'],\n",
+ " 'alliances': ['▁', 'all', 'i', 'ance', 's'],\n",
+ " 'allied': ['▁', 'all', 'i', 'ed'],\n",
+ " 'allies': ['▁', 'all', 'ies'],\n",
+ " 'allow': ['▁', 'allow'],\n",
+ " 'allowance': ['▁', 'allow', 'ance'],\n",
+ " 'allowances': ['▁', 'allow', 'ance', 's'],\n",
+ " 'allowed': ['▁', 'allow', 'ed'],\n",
+ " 'allowing': ['▁', 'allow', 'ing'],\n",
+ " 'ally': ['▁', 'al', 'ly'],\n",
+ " 'almost': ['▁', 'al', 'most'],\n",
+ " 'alone': ['▁', 'al', 'one'],\n",
+ " 'along': ['▁', 'a', 'long'],\n",
+ " 'alongside': ['▁', 'a', 'long', 'side'],\n",
+ " 'aloud': ['▁', 'a', 'lo', 'ud'],\n",
+ " 'already': ['▁', 'al', 'read', 'y'],\n",
+ " 'also': ['▁', 'also'],\n",
+ " 'alter': ['▁', 'al', 'ter'],\n",
+ " 'alternative': ['▁', 'al', 'ter', 'n', 'at', 'ive'],\n",
+ " 'alternatively': ['▁', 'al', 'ter', 'n', 'at', 'ive', 'ly'],\n",
+ " 'alternatives': ['▁', 'al', 'ter', 'n', 'at', 'ive', 's'],\n",
+ " 'although': ['▁', 'al', 'though'],\n",
+ " 'altogether': ['▁', 'al', 'together'],\n",
+ " 'altos': ['▁', 'al', 'to', 's'],\n",
+ " 'always': ['▁', 'always'],\n",
+ " 'am': ['▁', 'am'],\n",
+ " 'amateur': ['▁', 'am', 'ate', 'ur'],\n",
+ " 'amazed': ['▁', 'a', 'ma', 'z', 'ed'],\n",
+ " 'amazing': ['▁', 'a', 'ma', 'z', 'ing'],\n",
+ " 'ambassador': ['▁', 'am', 'bas', 's', 'ad', 'or'],\n",
+ " 'amber': ['▁', 'a', 'mber'],\n",
+ " 'ambition': ['▁', 'am', 'b', 'it', 'ion'],\n",
+ " 'ambitious': ['▁', 'am', 'b', 'it', 'i', 'ous'],\n",
+ " 'ambulance': ['▁', 'am', 'b', 'ul', 'ance'],\n",
+ " 'ambulances': ['▁', 'am', 'b', 'ul', 'ance', 's'],\n",
+ " 'america': ['▁', 'america'],\n",
+ " \"america's\": ['▁', 'america', \"'\", 's'],\n",
+ " 'american': ['▁', 'american'],\n",
+ " 'american-born': ['▁', 'american', '-', 'b', 'or', 'n'],\n",
+ " 'americans': ['▁', 'american', 's'],\n",
+ " 'amid': ['▁', 'am', 'id'],\n",
+ " 'ammunition': ['▁', 'am', 'm', 'un', 'it', 'ion'],\n",
+ " 'among': ['▁', 'among'],\n",
+ " 'amount': ['▁', 'a', 'mo', 'un', 't'],\n",
+ " 'ample': ['▁', 'amp', 'le'],\n",
+ " 'amusement': ['▁', 'am', 'use', 'ment'],\n",
+ " 'amusing': ['▁', 'am', 'us', 'ing'],\n",
+ " 'an': ['▁', 'an'],\n",
+ " 'analogy': ['▁', 'an', 'a', 'lo', 'g', 'y'],\n",
+ " 'analysed': ['▁', 'an', 'a', 'ly', 's', 'ed'],\n",
+ " 'anchor': ['▁', 'an', 'ch', 'or'],\n",
+ " 'ancient': ['▁', 'an', 'c', 'i', 'ent'],\n",
+ " 'and': ['▁', 'and'],\n",
+ " 'andrei': ['▁', 'and', 're', 'i'],\n",
+ " 'andrew': ['▁', 'and', 're', 'w'],\n",
+ " 'anecdotal': ['▁', 'an', 'e', 'c', 'do', 't', 'al'],\n",
+ " 'angel': ['▁', 'ang', 'el'],\n",
+ " 'angeles': ['▁', 'ang', 'el', 'es'],\n",
+ " 'angelo': ['▁', 'ang', 'e', 'lo'],\n",
+ " 'anger': ['▁', 'ang', 'er'],\n",
+ " 'anglais': ['▁', 'ang', 'la', 'is'],\n",
+ " 'angle': ['▁', 'ang', 'le'],\n",
+ " 'anglesey': ['▁', 'anglesey'],\n",
+ " \"anglesey's\": ['▁', 'anglesey', \"'\", 's'],\n",
+ " 'anglesey-road': ['▁', 'anglesey', '-', 'ro', 'ad'],\n",
+ " 'angola': ['▁', 'an', 'go', 'la'],\n",
+ " 'angrily': ['▁', 'an', 'gr', 'i', 'ly'],\n",
+ " 'angry': ['▁', 'ang', 'ry'],\n",
+ " 'ann': ['▁', 'an', 'n'],\n",
+ " 'anna': ['▁', 'an', 'n', 'a'],\n",
+ " 'announced': ['▁', 'an', 'no', 'un', 'c', 'ed'],\n",
+ " 'announcement': ['▁', 'an', 'no', 'un', 'ce', 'ment'],\n",
+ " 'announcing': ['▁', 'an', 'no', 'un', 'c', 'ing'],\n",
+ " 'annoyed': ['▁', 'an', 'no', 'y', 'ed'],\n",
+ " 'annual': ['▁', 'an', 'n', 'ual'],\n",
+ " 'another': ['▁', 'another'],\n",
+ " 'answer': ['▁', 'answer'],\n",
+ " 'answered': ['▁', 'answer', 'ed'],\n",
+ " 'answering': ['▁', 'answer', 'ing'],\n",
+ " 'antagonism': ['▁', 'ant', 'a', 'g', 'on', 'is', 'm'],\n",
+ " 'anthony': ['▁', 'an', 'th', 'on', 'y'],\n",
+ " 'anti-apartheid': ['▁', 'ant', 'i', '-', 'a', 'part', 'he', 'id'],\n",
+ " 'anti-bomb': ['▁', 'ant', 'i', '-', 'bomb'],\n",
+ " 'anti-german': ['▁', 'ant', 'i', '-', 'german'],\n",
+ " 'anti-nato': ['▁', 'ant', 'i', '-', 'nato'],\n",
+ " 'anti-negro': ['▁', 'ant', 'i', '-', 'negro'],\n",
+ " 'anti-nuclear': ['▁', 'ant', 'i', '-', 'nuclear'],\n",
+ " 'anti-soviet': ['▁', 'ant', 'i', '-', 'soviet'],\n",
+ " 'anti-tory': ['▁', 'ant', 'i', '-', 'tory'],\n",
+ " 'anticipation': ['▁', 'an', 'tic', 'ip', 'ation'],\n",
+ " 'antonioni': ['▁', 'ant', 'on', 'ion', 'i'],\n",
+ " \"antonioni's\": ['▁', 'ant', 'on', 'ion', 'i', \"'\", 's'],\n",
+ " 'any': ['▁', 'any'],\n",
+ " 'any-': ['▁', 'any', '-'],\n",
+ " 'anybody': ['▁', 'any', 'body'],\n",
+ " \"anybody's\": ['▁', 'any', 'body', \"'\", 's'],\n",
+ " 'anyone': ['▁', 'any', 'one'],\n",
+ " 'anything': ['▁', 'any', 'thing'],\n",
+ " 'anyway': ['▁', 'any', 'way'],\n",
+ " 'apart': ['▁', 'a', 'part'],\n",
+ " 'apartheid': ['▁', 'a', 'part', 'he', 'id'],\n",
+ " 'apathetic': ['▁', 'a', 'pa', 'the', 'tic'],\n",
+ " 'apathy': ['▁', 'a', 'pa', 'th', 'y'],\n",
+ " 'apex': ['▁', 'ap', 'ex'],\n",
+ " 'apocalypse': ['▁', 'a', 'po', 'c', 'a', 'ly', 'p', 'se'],\n",
+ " 'apologising': ['▁', 'a', 'po', 'lo', 'g', 'is', 'ing'],\n",
+ " 'appalled': ['▁', 'app', 'all', 'ed'],\n",
+ " 'appalling': ['▁', 'app', 'all', 'ing'],\n",
+ " 'apparatus': ['▁', 'app', 'ar', 'at', 'us'],\n",
+ " 'apparent': ['▁', 'app', 'ar', 'ent'],\n",
+ " 'apparently': ['▁', 'app', 'ar', 'ent', 'ly'],\n",
+ " 'appeal': ['▁', 'appeal'],\n",
+ " 'appealing': ['▁', 'appeal', 'ing'],\n",
+ " 'appeals': ['▁', 'appeal', 's'],\n",
+ " 'appear': ['▁', 'appear'],\n",
+ " 'appearance': ['▁', 'appear', 'ance'],\n",
+ " 'appeared': ['▁', 'appear', 'ed'],\n",
+ " 'appears': ['▁', 'appear', 's'],\n",
+ " 'appeasement': ['▁', 'app', 'e', 'a', 'se', 'ment'],\n",
+ " 'applauding': ['▁', 'app', 'la', 'ud', 'ing'],\n",
+ " 'appliances': ['▁', 'app', 'li', 'ance', 's'],\n",
+ " 'application': ['▁', 'app', 'li', 'c', 'ation'],\n",
+ " 'applications': ['▁', 'app', 'li', 'c', 'ation', 's'],\n",
+ " 'applied': ['▁', 'app', 'li', 'ed'],\n",
+ " 'apply': ['▁', 'app', 'ly'],\n",
+ " 'appointed': ['▁', 'ap', 'point', 'ed'],\n",
+ " 'appointment': ['▁', 'ap', 'point', 'ment'],\n",
+ " 'appreciable': ['▁', 'app', 're', 'c', 'i', 'able'],\n",
+ " 'appreciably': ['▁', 'app', 're', 'c', 'i', 'ably'],\n",
+ " 'appreciated': ['▁', 'app', 're', 'c', 'i', 'at', 'ed'],\n",
+ " 'appreciation': ['▁', 'app', 're', 'c', 'i', 'ation'],\n",
+ " 'apprenticeships': ['▁', 'app', 'r', 'ent', 'i', 'ce', 'ship', 's'],\n",
+ " 'approach': ['▁', 'ap', 'pro', 'a', 'ch'],\n",
+ " 'approached': ['▁', 'ap', 'pro', 'a', 'ch', 'ed'],\n",
+ " 'approaches': ['▁', 'ap', 'pro', 'a', 'che', 's'],\n",
+ " 'appropriate': ['▁', 'ap', 'pro', 'pri', 'ate'],\n",
+ " 'appropriated': ['▁', 'ap', 'pro', 'pri', 'at', 'ed'],\n",
+ " 'approval': ['▁', 'ap', 'pro', 'val'],\n",
+ " 'approximately': ['▁', 'ap', 'pro', 'x', 'im', 'ate', 'ly'],\n",
+ " 'april': ['▁', 'a', 'pri', 'l'],\n",
+ " 'archbishop': ['▁', 'ar', 'ch', 'b', 'is', 'hop'],\n",
+ " 'arches': ['▁', 'ar', 'che', 's'],\n",
+ " 'archipelago': ['▁', 'ar', 'ch', 'i', 'pe', 'la', 'go'],\n",
+ " 'architect': ['▁', 'ar', 'ch', 'it', 'e', 'c', 't'],\n",
+ " 'architecture': ['▁', 'ar', 'ch', 'it', 'e', 'c', 'ture'],\n",
+ " 'are': ['▁', 'are'],\n",
+ " 'area': ['▁', 'are', 'a'],\n",
+ " 'areas': ['▁', 'are', 'as'],\n",
+ " \"aren't\": ['▁', 'are', 'n', \"'\", 't'],\n",
+ " 'arguably': ['▁', 'ar', 'gu', 'ably'],\n",
+ " 'argued': ['▁', 'ar', 'gu', 'ed'],\n",
+ " 'argues': ['▁', 'ar', 'gu', 'es'],\n",
+ " 'arguing': ['▁', 'ar', 'gu', 'ing'],\n",
+ " 'argument': ['▁', 'ar', 'gu', 'ment'],\n",
+ " 'arguments': ['▁', 'ar', 'gu', 'ment', 's'],\n",
+ " 'arise': ['▁', 'a', 'rise'],\n",
+ " 'arises': ['▁', 'a', 'rise', 's'],\n",
+ " 'arm': ['▁', 'arm'],\n",
+ " 'armament': ['▁', 'arm', 'a', 'ment'],\n",
+ " 'armaments': ['▁', 'arm', 'a', 'ment', 's'],\n",
+ " 'armed': ['▁', 'arm', 'ed'],\n",
+ " 'armoured': ['▁', 'arm', 'our', 'ed'],\n",
+ " 'arms': ['▁', 'arm', 's'],\n",
+ " \"arms'\": ['▁', 'arm', 's', \"'\"],\n",
+ " 'army': ['▁', 'arm', 'y'],\n",
+ " 'arnold': ['▁', 'ar', 'n', 'old'],\n",
+ " 'arose': ['▁', 'a', 'ro', 'se'],\n",
+ " 'around': ['▁', 'a', 'round'],\n",
+ " 'aroused': ['▁', 'ar', 'ous', 'ed'],\n",
+ " 'arrange': ['▁', 'ar', 'range'],\n",
+ " 'arranged': ['▁', 'ar', 'range', 'd'],\n",
+ " 'arrangement': ['▁', 'ar', 'range', 'ment'],\n",
+ " 'arrangements': ['▁', 'ar', 'range', 'ment', 's'],\n",
+ " 'arranging': ['▁', 'ar', 'r', 'ang', 'ing'],\n",
+ " 'arrears': ['▁', 'ar', 're', 'ar', 's'],\n",
+ " 'arrested': ['▁', 'ar', 'rest', 'ed'],\n",
+ " 'arrival': ['▁', 'ar', 'r', 'i', 'val'],\n",
+ " 'arrive': ['▁', 'ar', 'r', 'ive'],\n",
+ " 'arrived': ['▁', 'arrived'],\n",
+ " 'arrives': ['▁', 'ar', 'r', 'ive', 's'],\n",
+ " 'arrogant': ['▁', 'ar', 'ro', 'g', 'ant'],\n",
+ " 'art': ['▁', 'ar', 't'],\n",
+ " 'arthur': ['▁', 'ar', 'th', 'ur'],\n",
+ " 'article': ['▁', 'ar', 'tic', 'le'],\n",
+ " 'articles': ['▁', 'ar', 'tic', 'le', 's'],\n",
+ " 'articulation': ['▁', 'ar', 'tic', 'ul', 'ation'],\n",
+ " 'artistic': ['▁', 'ar', 'tist', 'ic'],\n",
+ " 'artistically': ['▁', 'ar', 'tist', 'ical', 'ly'],\n",
+ " 'artistry': ['▁', 'ar', 'tist', 'ry'],\n",
+ " 'artists': ['▁', 'ar', 'tist', 's'],\n",
+ " 'as': ['▁', 'as'],\n",
+ " 'ascents': ['▁', 'as', 'cent', 's'],\n",
+ " 'ash': ['▁', 'as', 'h'],\n",
+ " 'ashen': ['▁', 'as', 'he', 'n'],\n",
+ " 'ask': ['▁', 'as', 'k'],\n",
+ " 'asked': ['▁', 'asked'],\n",
+ " 'asking': ['▁', 'asking'],\n",
+ " 'aspect': ['▁', 'a', 'spect'],\n",
+ " 'aspects': ['▁', 'a', 'spect', 's'],\n",
+ " 'aspiring': ['▁', 'as', 'p', 'i', 'r', 'ing'],\n",
+ " 'assault': ['▁', 'as', 's', 'a', 'ul', 't'],\n",
+ " 'assembler': ['▁', 'as', 'se', 'm', 'bl', 'er'],\n",
+ " 'assembly': ['▁', 'as', 'se', 'm', 'b', 'ly'],\n",
+ " 'assess': ['▁', 'as', 'se', 's', 's'],\n",
+ " 'assessment': ['▁', 'as', 'se', 's', 's', 'ment'],\n",
+ " 'assistance': ['▁', 'as', 's', 'istance'],\n",
+ " 'assistant': ['▁', 'as', 's', 'is', 't', 'ant'],\n",
+ " 'assistants': ['▁', 'as', 's', 'is', 't', 'ant', 's'],\n",
+ " 'associate': ['▁', 'associat', 'e'],\n",
+ " 'associated': ['▁', 'associat', 'ed'],\n",
+ " 'associates': ['▁', 'associat', 'es'],\n",
+ " 'association': ['▁', 'associat', 'ion'],\n",
+ " 'assortment': ['▁', 'as', 's', 'or', 't', 'ment'],\n",
+ " 'assumption': ['▁', 'assumption'],\n",
+ " 'assurance': ['▁', 'as', 's', 'ur', 'ance'],\n",
+ " 'astronaut': ['▁', 'as', 'tr', 'on', 'a', 'u', 't'],\n",
+ " 'astute': ['▁', 'a', 'st', 'u', 'te'],\n",
+ " 'at': ['▁', 'at'],\n",
+ " 'ately': ['▁', 'ate', 'ly'],\n",
+ " 'atkinson': ['▁', 'at', 'k', 'in', 's', 'on'],\n",
+ " 'atlantic': ['▁', 'at', 'l', 'an', 'tic'],\n",
+ " 'atmosphere': ['▁', 'atmospher', 'e'],\n",
+ " 'atmospheric': ['▁', 'atmospher', 'ic'],\n",
+ " 'atomic': ['▁', 'a', 'to', 'm', 'ic'],\n",
+ " 'atoms': ['▁', 'a', 'to', 'm', 's'],\n",
+ " 'attach': ['▁', 'at', 't', 'a', 'ch'],\n",
+ " 'attached': ['▁', 'at', 't', 'a', 'ch', 'ed'],\n",
+ " 'attack': ['▁', 'at', 't', 'a', 'ck'],\n",
+ " 'attacked': ['▁', 'at', 't', 'a', 'ck', 'ed'],\n",
+ " 'attacks': ['▁', 'at', 't', 'a', 'ck', 's'],\n",
+ " 'attainable': ['▁', 'at', 'tain', 'able'],\n",
+ " 'attempt': ['▁', 'attempt'],\n",
+ " 'attempted': ['▁', 'attempt', 'ed'],\n",
+ " 'attempting': ['▁', 'attempt', 'ing'],\n",
+ " 'attempts': ['▁', 'attempt', 's'],\n",
+ " 'atten-': ['▁', 'at', 'ten', '-'],\n",
+ " 'attend': ['▁', 'at', 't', 'end'],\n",
+ " 'attendance': ['▁', 'at', 't', 'end', 'ance'],\n",
+ " 'attended': ['▁', 'at', 't', 'end', 'ed'],\n",
+ " 'attending': ['▁', 'at', 't', 'end', 'ing'],\n",
+ " 'attention': ['▁', 'at', 'ten', 'tion'],\n",
+ " 'attitude': ['▁', 'at', 't', 'it', 'u', 'de'],\n",
+ " 'attitudes': ['▁', 'at', 't', 'it', 'ud', 'es'],\n",
+ " 'attracted': ['▁', 'at', 'tr', 'act', 'ed'],\n",
+ " 'attractive': ['▁', 'at', 'tr', 'act', 'ive'],\n",
+ " 'aubrey': ['▁', 'a', 'u', 'b', 're', 'y'],\n",
+ " 'audacity': ['▁', 'a', 'ud', 'ac', 'ity'],\n",
+ " 'auden': ['▁', 'a', 'ud', 'en'],\n",
+ " 'audience': ['▁', 'a', 'ud', 'i', 'ence'],\n",
+ " 'audio-tv': ['▁', 'a', 'ud', 'i', 'o', '-', 't', 'v'],\n",
+ " 'audited': ['▁', 'a', 'ud', 'it', 'ed'],\n",
+ " 'august': ['▁', 'a', 'ug', 'u', 'st'],\n",
+ " 'auntie': ['▁', 'a', 'un', 't', 'i', 'e'],\n",
+ " 'austerity': ['▁', 'a', 'u', 'ster', 'ity'],\n",
+ " 'australia': ['▁', 'a', 'us', 'tr', 'al', 'i', 'a'],\n",
+ " 'austria': ['▁', 'a', 'us', 'tri', 'a'],\n",
+ " 'austrian': ['▁', 'a', 'us', 'tri', 'an'],\n",
+ " 'authentic': ['▁', 'a', 'u', 'then', 'tic'],\n",
+ " 'author': ['▁', 'author'],\n",
+ " 'authorised': ['▁', 'author', 'is', 'ed'],\n",
+ " 'authorities': ['▁', 'author', 'it', 'ies'],\n",
+ " 'authority': ['▁', 'author', 'ity'],\n",
+ " 'automatically': ['▁', 'a', 'u', 'to', 'm', 'at', 'ical', 'ly'],\n",
+ " 'automation': ['▁', 'a', 'u', 'to', 'm', 'ation'],\n",
+ " 'autumn': ['▁', 'a', 'u', 't', 'um', 'n'],\n",
+ " 'available': ['▁', 'a', 'v', 'a', 'il', 'able'],\n",
+ " 'avenue': ['▁', 'a', 've', 'n', 'ue'],\n",
+ " 'average': ['▁', 'a', 'ver', 'age'],\n",
+ " 'averages': ['▁', 'a', 'ver', 'age', 's'],\n",
+ " 'avert': ['▁', 'a', 'ver', 't'],\n",
+ " 'aviation': ['▁', 'a', 'vi', 'ation'],\n",
+ " 'avoid': ['▁', 'a', 'v', 'o', 'id'],\n",
+ " 'avoided': ['▁', 'a', 'v', 'o', 'id', 'ed'],\n",
+ " 'avon': ['▁', 'a', 'v', 'on'],\n",
+ " 'awake': ['▁', 'a', 'w', 'a', 'ke'],\n",
+ " 'awarded': ['▁', 'a', 'ward', 'ed'],\n",
+ " 'awards': ['▁', 'a', 'ward', 's'],\n",
+ " 'aware': ['▁', 'a', 'w', 'are'],\n",
+ " 'awareness': ['▁', 'a', 'w', 'are', 'ness'],\n",
+ " 'away': ['▁', 'a', 'way'],\n",
+ " 'awful': ['▁', 'a', 'w', 'ful'],\n",
+ " 'awfully': ['▁', 'a', 'w', 'ful', 'ly'],\n",
+ " 'b': ['▁', 'b'],\n",
+ " 'b.': ['▁', 'b', '.'],\n",
+ " 'b.b.c.': ['▁', 'b', '.', 'b', '.', 'c', '.'],\n",
+ " 'babe': ['▁', 'b', 'a', 'be'],\n",
+ " 'babel': ['▁', 'b', 'a', 'be', 'l'],\n",
+ " 'bably': ['▁', 'b', 'ably'],\n",
+ " 'baby': ['▁', 'b', 'a', 'by'],\n",
+ " \"baby's\": ['▁', 'b', 'a', 'by', \"'\", 's'],\n",
+ " 'back': ['▁', 'back'],\n",
+ " 'backbone': ['▁', 'back', 'b', 'one'],\n",
+ " 'backed': ['▁', 'back', 'ed'],\n",
+ " 'backers': ['▁', 'back', 'ers'],\n",
+ " 'background': ['▁', 'back', 'ground'],\n",
+ " 'backing': ['▁', 'back', 'ing'],\n",
+ " 'backstage': ['▁', 'back', 'st', 'age'],\n",
+ " 'backward': ['▁', 'back', 'ward'],\n",
+ " 'bad': ['▁', 'b', 'ad'],\n",
+ " 'badly': ['▁', 'b', 'ad', 'ly'],\n",
+ " 'baffled': ['▁', 'b', 'a', 'f', 'f', 'led'],\n",
+ " 'bag': ['▁', 'b', 'a', 'g'],\n",
+ " 'bagaya': ['▁', 'b', 'a', 'gay', 'a'],\n",
+ " 'baker': ['▁', 'b', 'a', 'k', 'er'],\n",
+ " 'balance': ['▁', 'b', 'al', 'ance'],\n",
+ " 'balance-sheet': ['▁', 'b', 'al', 'ance', '-', 'she', 'e', 't'],\n",
+ " 'balances': ['▁', 'b', 'al', 'ance', 's'],\n",
+ " 'bald': ['▁', 'b', 'al', 'd'],\n",
+ " 'ball': ['▁', 'b', 'all'],\n",
+ " 'balloon': ['▁', 'b', 'all', 'o', 'on'],\n",
+ " 'ballyhoo': ['▁', 'b', 'al', 'ly', 'ho', 'o'],\n",
+ " 'baltic': ['▁', 'b', 'al', 'tic'],\n",
+ " 'ban': ['▁', 'b', 'an'],\n",
+ " 'ban-': ['▁', 'b', 'an', '-'],\n",
+ " 'ban-the-': ['▁', 'b', 'an', '-', 'the', '-'],\n",
+ " 'ban-the-bomb': ['▁', 'b', 'an', '-', 'the', '-', 'bomb'],\n",
+ " 'bank': ['▁', 'bank'],\n",
+ " \"bank's\": ['▁', 'bank', \"'\", 's'],\n",
+ " 'banking': ['▁', 'bank', 'ing'],\n",
+ " 'bankrupt': ['▁', 'bank', 'r', 'up', 't'],\n",
+ " 'banks': ['▁', 'bank', 's'],\n",
+ " \"banks'\": ['▁', 'bank', 's', \"'\"],\n",
+ " 'banned': ['▁', 'b', 'an', 'n', 'ed'],\n",
+ " 'banzie': ['▁', 'b', 'an', 'z', 'i', 'e'],\n",
+ " 'bar': ['▁', 'b', 'ar'],\n",
+ " 'barb': ['▁', 'b', 'ar', 'b'],\n",
+ " 'barbara': ['▁', 'b', 'ar', 'b', 'ar', 'a'],\n",
+ " 'barbarously': ['▁', 'b', 'ar', 'b', 'ar', 'ous', 'ly'],\n",
+ " 'barclay': ['▁', 'b', 'ar', 'clay'],\n",
+ " 'bare': ['▁', 'b', 'are'],\n",
+ " 'bargain': ['▁', 'b', 'ar', 'g', 'a', 'in'],\n",
+ " 'bargaining': ['▁', 'b', 'ar', 'g', 'a', 'in', 'ing'],\n",
+ " 'bark': ['▁', 'b', 'ar', 'k'],\n",
+ " 'barrier': ['▁', 'b', 'ar', 'r', 'i', 'er'],\n",
+ " 'barriers': ['▁', 'b', 'ar', 'r', 'i', 'ers'],\n",
+ " 'barry': ['▁', 'b', 'a', 'rry'],\n",
+ " 'base': ['▁', 'base'],\n",
+ " 'based': ['▁', 'bas', 'ed'],\n",
+ " 'bases': ['▁', 'base', 's'],\n",
+ " 'basic': ['▁', 'bas', 'ic'],\n",
+ " 'basin': ['▁', 'bas', 'in'],\n",
+ " 'basing': ['▁', 'bas', 'ing'],\n",
+ " 'basis': ['▁', 'bas', 'is'],\n",
+ " 'baskerville': ['▁', 'bas', 'k', 'er', 'v', 'il', 'le'],\n",
+ " 'basses': ['▁', 'bas', 'se', 's'],\n",
+ " 'basting': ['▁', 'bas', 't', 'ing'],\n",
+ " 'bathing': ['▁', 'b', 'a', 'thing'],\n",
+ " 'bats': ['▁', 'b', 'at', 's'],\n",
+ " 'batsman': ['▁', 'b', 'at', 's', 'man'],\n",
+ " 'battalions': ['▁', 'b', 'at', 't', 'al', 'ion', 's'],\n",
+ " 'batting': ['▁', 'b', 'at', 't', 'ing'],\n",
+ " 'battle': ['▁', 'b', 'a', 'ttle'],\n",
+ " 'bavaria': ['▁', 'b', 'a', 'v', 'ar', 'i', 'a'],\n",
+ " 'bavarian': ['▁', 'b', 'a', 'v', 'ar', 'i', 'an'],\n",
+ " 'bavarians': ['▁', 'b', 'a', 'v', 'ar', 'i', 'an', 's'],\n",
+ " 'bay': ['▁', 'b', 'a', 'y'],\n",
+ " 'be': ['▁', 'be'],\n",
+ " 'beach': ['▁', 'b', 'each'],\n",
+ " 'beaches': ['▁', 'b', 'each', 'es'],\n",
+ " 'beacon': ['▁', 'be', 'a', 'con'],\n",
+ " 'beaks': ['▁', 'be', 'a', 'k', 's'],\n",
+ " 'bean': ['▁', 'be', 'an'],\n",
+ " 'bear': ['▁', 'be', 'ar'],\n",
+ " 'bearer': ['▁', 'be', 'are', 'r'],\n",
+ " 'bears': ['▁', 'be', 'ar', 's'],\n",
+ " 'beastly': ['▁', 'b', 'east', 'ly'],\n",
+ " 'beasts': ['▁', 'b', 'east', 's'],\n",
+ " 'beaten': ['▁', 'be', 'a', 'ten'],\n",
+ " 'beautiful': ['▁', 'be', 'a', 'u', 't', 'i', 'ful'],\n",
+ " 'beautifully': ['▁', 'be', 'a', 'u', 't', 'i', 'ful', 'ly'],\n",
+ " 'beauty': ['▁', 'be', 'a', 'u', 'ty'],\n",
+ " 'became': ['▁', 'be', 'came'],\n",
+ " 'because': ['▁', 'because'],\n",
+ " 'beckoning': ['▁', 'be', 'ck', 'on', 'ing'],\n",
+ " 'become': ['▁', 'be', 'come'],\n",
+ " 'becomes': ['▁', 'be', 'come', 's'],\n",
+ " 'becoming': ['▁', 'be', 'com', 'ing'],\n",
+ " 'bed': ['▁', 'b', 'ed'],\n",
+ " 'bedlam': ['▁', 'b', 'ed', 'la', 'm'],\n",
+ " 'beds': ['▁', 'b', 'ed', 's'],\n",
+ " 'bedspreads': ['▁', 'b', 'ed', 's', 'p', 'read', 's'],\n",
+ " 'beech': ['▁', 'be', 'e', 'ch'],\n",
+ " 'been': ['▁', 'been'],\n",
+ " 'before': ['▁', 'before'],\n",
+ " 'befriended': ['▁', 'be', 'friend', 'ed'],\n",
+ " 'began': ['▁', 'be', 'g', 'an'],\n",
+ " 'begin': ['▁', 'be', 'g', 'in'],\n",
+ " 'beginner': ['▁', 'be', 'g', 'in', 'n', 'er'],\n",
+ " 'beginning': ['▁', 'be', 'g', 'in', 'n', 'ing'],\n",
+ " 'begins': ['▁', 'be', 'g', 'in', 's'],\n",
+ " 'begun': ['▁', 'be', 'g', 'un'],\n",
+ " 'behan': ['▁', 'be', 'h', 'an'],\n",
+ " 'behave': ['▁', 'be', 'have'],\n",
+ " 'behaviour': ['▁', 'be', 'h', 'a', 'vi', 'our'],\n",
+ " 'behind': ['▁', 'behind'],\n",
+ " 'beier': ['▁', 'be', 'i', 'er'],\n",
+ " 'being': ['▁', 'being'],\n",
+ " 'belgian': ['▁', 'be', 'l', 'g', 'i', 'an'],\n",
+ " 'belgium': ['▁', 'be', 'l', 'giu', 'm'],\n",
+ " 'belgrade': ['▁', 'be', 'l', 'gr', 'a', 'de'],\n",
+ " 'belief': ['▁', 'be', 'li', 'e', 'f'],\n",
+ " 'believe': ['▁', 'believe'],\n",
+ " 'believed': ['▁', 'believed'],\n",
+ " 'believes': ['▁', 'believe', 's'],\n",
+ " 'bell': ['▁', 'be', 'll'],\n",
+ " \"bell's\": ['▁', 'be', 'll', \"'\", 's'],\n",
+ " 'belmondo': ['▁', 'be', 'l', 'mon', 'do'],\n",
+ " 'belonged': ['▁', 'be', 'long', 'ed'],\n",
+ " 'belongs': ['▁', 'be', 'long', 's'],\n",
+ " 'below': ['▁', 'be', 'low'],\n",
+ " 'belt': ['▁', 'be', 'l', 't'],\n",
+ " 'ben': ['▁', 'be', 'n'],\n",
+ " 'bench': ['▁', 'be', 'n', 'ch'],\n",
+ " 'benches': ['▁', 'be', 'n', 'che', 's'],\n",
+ " 'bend': ['▁', 'b', 'end'],\n",
+ " 'bending': ['▁', 'b', 'end', 'ing'],\n",
+ " 'benefits': ['▁', 'be', 'ne', 'f', 'its'],\n",
+ " 'bent': ['▁', 'b', 'ent'],\n",
+ " 'ber': ['▁', 'be', 'r'],\n",
+ " 'berlin': ['▁', 'berlin'],\n",
+ " \"berlin's\": ['▁', 'berlin', \"'\", 's'],\n",
+ " 'bernhard': ['▁', 'be', 'r', 'n', 'hard'],\n",
+ " 'berry': ['▁', 'be', 'rry'],\n",
+ " 'bertrand': ['▁', 'bert', 'r', 'and'],\n",
+ " 'beset': ['▁', 'be', 'set'],\n",
+ " 'beside': ['▁', 'be', 'side'],\n",
+ " 'best': ['▁', 'best'],\n",
+ " 'best-seller': ['▁', 'best', '-', 's', 'ell', 'er'],\n",
+ " 'bet': ['▁', 'be', 't'],\n",
+ " 'betjeman': ['▁', 'be', 't', 'je', 'man'],\n",
+ " 'betrayal': ['▁', 'be', 'tr', 'a', 'y', 'al'],\n",
+ " 'betrayed': ['▁', 'be', 'tr', 'a', 'y', 'ed'],\n",
+ " 'better': ['▁', 'better'],\n",
+ " 'better-': ['▁', 'better', '-'],\n",
+ " \"betti's\": ['▁', 'be', 't', 't', 'i', \"'\", 's'],\n",
+ " 'between': ['▁', 'between'],\n",
+ " 'bevel': ['▁', 'be', 've', 'l'],\n",
+ " 'bevelled': ['▁', 'be', 'v', 'ell', 'ed'],\n",
+ " 'beware': ['▁', 'be', 'w', 'are'],\n",
+ " 'bewildered': ['▁', 'be', 'w', 'il', 'd', 'er', 'ed'],\n",
+ " 'beyond': ['▁', 'beyond'],\n",
+ " 'bidet': ['▁', 'b', 'i', 'de', 't'],\n",
+ " 'big': ['▁', 'big'],\n",
+ " 'bigger': ['▁', 'big', 'g', 'er'],\n",
+ " 'biggest': ['▁', 'big', 'g', 'est'],\n",
+ " 'bill': ['▁', 'b', 'ill'],\n",
+ " 'bills': ['▁', 'b', 'ill', 's'],\n",
+ " 'binding': ['▁', 'b', 'in', 'd', 'ing'],\n",
+ " 'biological': ['▁', 'b', 'i', 'o', 'lo', 'g', 'ical'],\n",
+ " 'bird': ['▁', 'b', 'i', 'r', 'd'],\n",
+ " 'birds': ['▁', 'b', 'i', 'r', 'd', 's'],\n",
+ " 'bishop': ['▁', 'b', 'is', 'hop'],\n",
+ " 'bit': ['▁', 'b', 'it'],\n",
+ " 'bite': ['▁', 'b', 'it', 'e'],\n",
+ " 'bits': ['▁', 'b', 'its'],\n",
+ " 'bitter-sweet': ['▁', 'b', 'it', 'ter', '-', 's', 'we', 'e', 't'],\n",
+ " 'bitterest': ['▁', 'b', 'it', 'ter', 'est'],\n",
+ " 'bitterly': ['▁', 'b', 'it', 'ter', 'ly'],\n",
+ " 'bituminized': ['▁', 'b', 'it', 'um', 'in', 'i', 'z', 'ed'],\n",
+ " 'black': ['▁', 'bl', 'a', 'ck'],\n",
+ " 'black-': ['▁', 'bl', 'a', 'ck', '-'],\n",
+ " 'black-listed': ['▁', 'bl', 'a', 'ck', '-', 'li', 'st', 'ed'],\n",
+ " 'blackbird': ['▁', 'bl', 'a', 'ck', 'b', 'i', 'r', 'd'],\n",
+ " 'blacks': ['▁', 'bl', 'a', 'ck', 's'],\n",
+ " 'blame': ['▁', 'bl', 'a', 'me'],\n",
+ " 'blamed': ['▁', 'bl', 'am', 'ed'],\n",
+ " 'blander': ['▁', 'bl', 'and', 'er'],\n",
+ " 'blank': ['▁', 'bl', 'an', 'k'],\n",
+ " 'blend': ['▁', 'bl', 'end'],\n",
+ " 'blight': ['▁', 'b', 'light'],\n",
+ " 'blind': ['▁', 'bl', 'in', 'd'],\n",
+ " 'blinked': ['▁', 'bl', 'in', 'k', 'ed'],\n",
+ " 'block': ['▁', 'block'],\n",
+ " 'blocks': ['▁', 'block', 's'],\n",
+ " 'bloem-': ['▁', 'b', 'lo', 'e', 'm', '-'],\n",
+ " 'blond': ['▁', 'bl', 'on', 'd'],\n",
+ " 'blood': ['▁', 'b', 'lo', 'od'],\n",
+ " 'bloodstained': ['▁', 'b', 'lo', 'od', 's', 'tain', 'ed'],\n",
+ " 'bloody': ['▁', 'b', 'lo', 'od', 'y'],\n",
+ " 'blouse': ['▁', 'b', 'lo', 'use'],\n",
+ " 'blouses': ['▁', 'bl', 'ous', 'es'],\n",
+ " 'blow': ['▁', 'b', 'low'],\n",
+ " 'blowflies': ['▁', 'b', 'low', 'f', 'l', 'ies'],\n",
+ " 'blown': ['▁', 'bl', 'own'],\n",
+ " 'blue': ['▁', 'bl', 'ue'],\n",
+ " 'blunt': ['▁', 'bl', 'un', 't'],\n",
+ " 'bluntly': ['▁', 'bl', 'un', 't', 'ly'],\n",
+ " 'bluster': ['▁', 'bl', 'u', 'ster'],\n",
+ " 'board': ['▁', 'board'],\n",
+ " 'boat': ['▁', 'bo', 'at'],\n",
+ " 'boat-train': ['▁', 'bo', 'at', '-', 'train'],\n",
+ " 'bobby': ['▁', 'bo', 'b', 'by'],\n",
+ " 'bodies': ['▁', 'bo', 'd', 'ies'],\n",
+ " 'body': ['▁', 'body'],\n",
+ " 'boeing': ['▁', 'bo', 'e', 'ing'],\n",
+ " 'bogy': ['▁', 'bo', 'g', 'y'],\n",
+ " 'boiled': ['▁', 'bo', 'il', 'ed'],\n",
+ " 'boils': ['▁', 'bo', 'il', 's'],\n",
+ " 'bold': ['▁', 'b', 'old'],\n",
+ " 'boldly': ['▁', 'b', 'old', 'ly'],\n",
+ " 'bolt': ['▁', 'bo', 'l', 't'],\n",
+ " 'bolted': ['▁', 'bo', 'l', 'ted'],\n",
+ " 'bomb': ['▁', 'bomb'],\n",
+ " 'bombay': ['▁', 'bomb', 'a', 'y'],\n",
+ " 'bombed': ['▁', 'bomb', 'ed'],\n",
+ " 'bombers': ['▁', 'bomb', 'ers'],\n",
+ " 'bonded': ['▁', 'b', 'on', 'd', 'ed'],\n",
+ " 'bone': ['▁', 'b', 'one'],\n",
+ " 'bones': ['▁', 'b', 'one', 's'],\n",
+ " 'bonn': ['▁', 'b', 'on', 'n'],\n",
+ " \"bonn's\": ['▁', 'b', 'on', 'n', \"'\", 's'],\n",
+ " 'book': ['▁', 'book'],\n",
+ " 'booklet': ['▁', 'book', 'le', 't'],\n",
+ " 'books': ['▁', 'book', 's'],\n",
+ " 'booming': ['▁', 'bo', 'o', 'm', 'ing'],\n",
+ " 'border': ['▁', 'b', 'order'],\n",
+ " 'bore': ['▁', 'bo', 're'],\n",
+ " 'bored': ['▁', 'b', 'or', 'ed'],\n",
+ " 'boredom': ['▁', 'bo', 're', 'do', 'm'],\n",
+ " 'bores': ['▁', 'bo', 're', 's'],\n",
+ " 'born': ['▁', 'b', 'or', 'n'],\n",
+ " 'borough': ['▁', 'bo', 'rough'],\n",
+ " 'borrow': ['▁', 'b', 'or', 'ro', 'w'],\n",
+ " 'borstal': ['▁', 'b', 'or', 'st', 'al'],\n",
+ " 'bosoms': ['▁', 'bo', 'so', 'm', 's'],\n",
+ " 'bossed': ['▁', 'bo', 's', 's', 'ed'],\n",
+ " 'bosses': ['▁', 'bo', 's', 'se', 's'],\n",
+ " 'both': ['▁', 'both'],\n",
+ " 'bottle': ['▁', 'bo', 'ttle'],\n",
+ " 'bottom': ['▁', 'bo', 't', 'to', 'm'],\n",
+ " 'bought': ['▁', 'bo', 'ug', 'h', 't'],\n",
+ " 'boun': ['▁', 'bo', 'un'],\n",
+ " 'bound': ['▁', 'b', 'ound'],\n",
+ " 'boutiques': ['▁', 'b', 'out', 'i', 'q', 'ue', 's'],\n",
+ " 'bow': ['▁', 'bo', 'w'],\n",
+ " 'bow-street': ['▁', 'bo', 'w', '-', 'st', 're', 'e', 't'],\n",
+ " 'bowed': ['▁', 'bo', 'w', 'ed'],\n",
+ " 'bowing': ['▁', 'bo', 'w', 'ing'],\n",
+ " 'bows': ['▁', 'bo', 'w', 's'],\n",
+ " 'box': ['▁', 'bo', 'x'],\n",
+ " 'boxes': ['▁', 'bo', 'x', 'es'],\n",
+ " 'boxing': ['▁', 'bo', 'x', 'ing'],\n",
+ " 'boy': ['▁', 'bo', 'y'],\n",
+ " 'boycotted': ['▁', 'bo', 'y', 'cott', 'ed'],\n",
+ " 'boycotting': ['▁', 'bo', 'y', 'cott', 'ing'],\n",
+ " 'boyd-orr': ['▁', 'bo', 'y', 'd', '-', 'or', 'r'],\n",
+ " 'boyle': ['▁', 'bo', 'y', 'le'],\n",
+ " 'boys': ['▁', 'bo', 'y', 's'],\n",
+ " 'braces': ['▁', 'br', 'a', 'ce', 's'],\n",
+ " 'brain': ['▁', 'b', 'rain'],\n",
+ " 'brain-activity': ['▁', 'b', 'rain', '-', 'act', 'i', 'v', 'ity'],\n",
+ " 'brain-children': ['▁', 'b', 'rain', '-', 'children'],\n",
+ " 'brains': ['▁', 'b', 'rain', 's'],\n",
+ " 'brandy': ['▁', 'br', 'and', 'y'],\n",
+ " 'brash': ['▁', 'br', 'as', 'h'],\n",
+ " 'brass': ['▁', 'br', 'as', 's'],\n",
+ " 'brauchitsch': ['▁', 'br', 'a', 'u', 'ch', 'its', 'ch'],\n",
+ " 'breach': ['▁', 'br', 'each'],\n",
+ " 'bread-and-butter': ['▁', 'b', 'read', '-', 'and', '-', 'but', 'ter'],\n",
+ " 'break': ['▁', 'b', 're', 'a', 'k'],\n",
+ " 'breaking': ['▁', 'b', 're', 'a', 'k', 'ing'],\n",
+ " 'breaks': ['▁', 'b', 're', 'a', 'k', 's'],\n",
+ " 'breath': ['▁', 'b', 're', 'a', 'th'],\n",
+ " 'breathing': ['▁', 'b', 're', 'a', 'thing'],\n",
+ " 'breathless': ['▁', 'b', 're', 'a', 'th', 'less'],\n",
+ " 'breeding': ['▁', 'b', 're', 'ed', 'ing'],\n",
+ " 'breezily': ['▁', 'b', 're', 'e', 'z', 'i', 'ly'],\n",
+ " 'brehm': ['▁', 'b', 're', 'h', 'm'],\n",
+ " 'brella': ['▁', 'br', 'ell', 'a'],\n",
+ " 'brenda': ['▁', 'br', 'end', 'a'],\n",
+ " 'brendan': ['▁', 'br', 'end', 'an'],\n",
+ " \"brendan's\": ['▁', 'br', 'end', 'an', \"'\", 's'],\n",
+ " 'brentano': ['▁', 'br', 'ent', 'a', 'no'],\n",
+ " 'brezhnev': ['▁', 'b', 're', 'z', 'h', 'ne', 'v'],\n",
+ " 'brian': ['▁', 'br', 'i', 'an'],\n",
+ " 'bridal': ['▁', 'br', 'id', 'al'],\n",
+ " 'bride': ['▁', 'br', 'i', 'de'],\n",
+ " 'brief': ['▁', 'brief'],\n",
+ " 'brief-': ['▁', 'brief', '-'],\n",
+ " 'briefcase': ['▁', 'brief', 'case'],\n",
+ " 'briefing': ['▁', 'brief', 'ing'],\n",
+ " 'brigadiers': ['▁', 'br', 'i', 'g', 'ad', 'i', 'ers'],\n",
+ " 'bright': ['▁', 'b', 'right'],\n",
+ " 'brighter': ['▁', 'b', 'right', 'er'],\n",
+ " 'brightly': ['▁', 'b', 'right', 'ly'],\n",
+ " \"brighton's\": ['▁', 'b', 'right', 'on', \"'\", 's'],\n",
+ " 'brilliant': ['▁', 'br', 'ill', 'i', 'ant'],\n",
+ " 'brilliantly': ['▁', 'br', 'ill', 'i', 'ant', 'ly'],\n",
+ " 'bring': ['▁', 'br', 'ing'],\n",
+ " 'brings': ['▁', 'br', 'ing', 's'],\n",
+ " 'bristled': ['▁', 'br', 'is', 't', 'led'],\n",
+ " 'bristol': ['▁', 'br', 'is', 'to', 'l'],\n",
+ " 'britain': ['▁', 'britain'],\n",
+ " \"britain's\": ['▁', 'britain', \"'\", 's'],\n",
+ " 'british': ['▁', 'british'],\n",
+ " 'british-owned': ['▁', 'british', '-', 'own', 'ed'],\n",
+ " 'britishers': ['▁', 'british', 'ers'],\n",
+ " 'brittle': ['▁', 'br', 'i', 'ttle'],\n",
+ " 'broad': ['▁', 'b', 'ro', 'ad'],\n",
+ " 'broadcast': ['▁', 'b', 'ro', 'ad', 'c', 'a', 'st'],\n",
+ " 'broadcasting': ['▁', 'b', 'ro', 'ad', 'c', 'a', 'st', 'ing'],\n",
+ " 'broke': ['▁', 'b', 'ro', 'ke'],\n",
+ " 'broken': ['▁', 'b', 'ro', 'k', 'en'],\n",
+ " 'bronx': ['▁', 'br', 'on', 'x'],\n",
+ " \"brook's\": ['▁', 'b', 'ro', 'o', 'k', \"'\", 's'],\n",
+ " 'brother': ['▁', 'brother'],\n",
+ " 'brother-': ['▁', 'brother', '-'],\n",
+ " 'brother-in-law': ['▁', 'brother', '-', 'in', '-', 'law'],\n",
+ " 'brought': ['▁', 'brought'],\n",
+ " 'brown': ['▁', 'brown'],\n",
+ " \"brown's\": ['▁', 'brown', \"'\", 's'],\n",
+ " 'bru\"cke': ['▁', 'br', 'u', '\"', 'ck', 'e'],\n",
+ " 'bruce': ['▁', 'br', 'u', 'ce'],\n",
+ " 'bruno': ['▁', 'br', 'un', 'o'],\n",
+ " 'brunswick': ['▁', 'br', 'un', 's', 'w', 'i', 'ck'],\n",
+ " 'brussels': ['▁', 'br', 'us', 's', 'el', 's'],\n",
+ " 'brutal': ['▁', 'br', 'u', 't', 'al'],\n",
+ " 'bryan': ['▁', 'br', 'y', 'an'],\n",
+ " 'bu\"ckerei': ['▁', 'b', 'u', '\"', 'ck', 'e', 're', 'i'],\n",
+ " 'buck': ['▁', 'b', 'u', 'ck'],\n",
+ " 'buckingham': ['▁', 'b', 'u', 'ck', 'ing', 'h', 'am'],\n",
+ " 'buckley': ['▁', 'b', 'u', 'ck', 'le', 'y'],\n",
+ " 'budge': ['▁', 'b', 'ud', 'g', 'e'],\n",
+ " 'budgerigar': ['▁', 'b', 'ud', 'g', 'er', 'i', 'g', 'ar'],\n",
+ " 'budget': ['▁', 'budget'],\n",
+ " 'budgetary': ['▁', 'budget', 'ary'],\n",
+ " 'budgette': ['▁', 'budget', 'te'],\n",
+ " 'buganda': ['▁', 'b', 'ug', 'and', 'a'],\n",
+ " 'build': ['▁', 'b', 'u', 'il', 'd'],\n",
+ " 'building': ['▁', 'building'],\n",
+ " ...}"
+ ]
+ },
+ "execution_count": 29,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "lex"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.2"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/src/notebooks/07-try-gtn.ipynb b/src/notebooks/07-try-gtn.ipynb
new file mode 100644
index 0000000..d366dec
--- /dev/null
+++ b/src/notebooks/07-try-gtn.ipynb
@@ -0,0 +1,155 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import gtn\n",
+ "from IPython.display import display, Image"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "1"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Make some graphs:\n",
+ "g1 = gtn.Graph()\n",
+ "g1.add_node(True) # Add a start node\n",
+ "g1.add_node() # Add an internal node\n",
+ "g1.add_node(False, True) # Add an accepting node\n",
+ "\n",
+ "\n",
+ "# Add arcs with (src node, dst node, label):\n",
+ "g1.add_arc(0, 1, 1)\n",
+ "g1.add_arc(0, 1, 2)\n",
+ "g1.add_arc(1, 2, 1)\n",
+ "g1.add_arc(1, 2, 0)\n",
+ "\n",
+ "\n",
+ "g2 = gtn.Graph()\n",
+ "g2.add_node(True, True)\n",
+ "g2.add_arc(0, 0, 1)\n",
+ "g2.add_arc(0, 0, 0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ "<IPython.core.display.Image object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEUAAACdCAIAAABtgiI8AAAABmJLR0QA/wD/AP+gvaeTAAAUNElEQVR4nO1ceVAUx/fvPTmWYwmwqyzKJceKWhohYqLghUlhIhCCV4mIBwYlIFoQSEw0KSGKJikTQ1kaIFErIPEoxQNRAdEEDyCBIJeIoNzIsRwL7DW/P94vXcPucu0OfqkUn79237yZeZ/pmdevX79uGkEQ6D8E+v/aAIoxyWdiY5LPxMYkn4mNST4TG5N8JjYm+UxsTPKZ2JjkM7ExyWdiY5LPxMYkn4mNST4TG5N8JjaYFF4rKSkpIyPDwcGhubl52bJl69evJx99+vSpUChsaGjg8XjDa2oFgiJ8/fXX1tbWHR0dBEF0dHRYW1sfO3aMrPDVV1+tXLlyNJragBo+L168YLFY33zzDZbExsbq6+u/evUKSxwdHZOTk0ejqQ2o4RMXF4cQevToEZbk5eUhhA4fPgx/CwoKdHR0Ojs7R9TUEtT4g/v37yOELC0tsWTatGkIoaKiIvibkpLi5eVlbGw8oqaWoIZPQ0MDQsjExARL3njjDYTQ8+fPEUIEQaSmpsJHP7ym9qCGj5GREUKIRqNhCfyWSCQIodzc3M7Ozvfff39ETe1BDR8nJyeEUGdnJ5Z0dHQghCwsLBBCKSkpPj4+enp6I2pqD2r4ODs7o3/fJUBjYyNCaNGiRVKp9Pz587iHGUaTEkuo8W/t7e1cLvfbb7/FkiNHjrDZ7JcvX6anp5uamkokkhE1KbGEsv708OHD9vb23d3dBEF0dXXZ29t//fXXBEFs2LBhx44do9GkBDSCuvn6pKSknJyc6dOnV1ZWenp6bt++XSwW83i8a9eueXh4DK9JlQ1U8pkI+K/F15N8JjaoHP+ohVwur66urqioMDU1dXZ2hvhgHEGVo1SFTCb7+eefp0+fju+lr68fHBzc1tY2fjcdLz79/f1+fn5sNjs0NLSwsLC3t7euru7EiRPTpk2bNm1acXHxON13XPgoFApvb28ul3v//n2lQ21tbR4eHhYWFjU1NeNx63Hhk5CQwGQyVckAOjs7nZ2dPTw8FAoF5bemnk9DQwOHw/niiy+G0SkoKGAymYmJiZTfnXo+u3fvFggE/f39w6uFhIRYWVkNDAxQe3eK+bS0tOjr6//www8jatbW1rLZ7F9++YVaA6jkU19fv2HDBhaLtWDBAtWjlZWVDAajubmZIIjExER/f3+hUDhlypTffvuNQhsobp+5c+cihJycnFQPqebfrl+/TqPRLC0tJ1z+DYBzGmr5qObfpFKpiYnJqlWrKMy/URm/3bp1i8PhqD1UWFhYU1Pj6+t79uxZqVS6fPlyhBCTyXzvvfeam5vFYnFiYiIlNlDJ5/bt24sXL1Z7aKj824oVK548eYImWv4NkJOTs2zZMlU5MXT+zcXFpa+vD020/BtC6OXLly0tLW+99ZbqoWHyb87Ozvr6+mii5d8QQn///TeNRpszZ47qoWHybwwGQygUoomWf0MIFRcXT58+nZzIBYyYf7Ozs0PU5d8o41NZWTlz5kxV+c2bNxFCnp6e8DcgIIDL5WZnZ2MFhUKBENqwYQMlZlDGp7q62tbWFj5uuVyO5SkpKR999BGLxYK/JiYmMTExJ06c6OnpQQh1d3fDfMlQjn6soCxfJRAIvL295XL5yZMnWSxWbGzsypUr7e3tR8y/zZo1a//+/cXFxbNnz6bADkp65f7+fhqNduHCBQ3OBd9w48YNSiyh5n1raWkhCGLq1KkanGtsbGxoaFhXV0eJJdTwefXqFULIzMxMs9MtLS3r6+spsYQaPm1tbUgLPlOnTm1qaqLEEmr4tLe3MxgMLper9uirV69u3LhRVVU11OkcDkcsFlNiCTV8ent79fT0yLOIGBcvXrS1tfXy8rK3t//kk08Ide5UV1cXHL32oIaPWCyGcEYJJSUl69ev37hxo0gkSktLO3ny5JEj
R1TV9PT0qOJDjb+Oj4+3srJSlfv6+rq6usrlcvgbGxtrYGDQ2tqqpLZjx47ly5dTYgk17TMwMKCrq6skfPHiRXp6emRkJJ3+/3cJDw/X1dVVHbpR2D7U8CHUfRW//vorj8fz8fHBEg6Hs2bNmpSUFCXNCff90Ghq4qasrKx3330XR24Ab2/voqIicnyNEGIwGBCVao/xmv/p7+9/8OCBUtiGEHrnnXfYbHZOTo6Sso6ODiX3pYYPnU5XesD5+fn9/f3u7u5KmhwOx8nJqbi4mCyUSCQTiw+bzVYaMJeWlhoaGlpbW6sqC4XC8vJysmRgYGBi8dHX11fq4CsrKx0cHNT2sEKhsLS0lCyhkA818436+vpKDqqiosLBwQF+9/X1/fHHH1KpFMZwDQ0Nz549O3r0qFwuB2FBQQFW1haU9GLnzp2DTwhLZs+e/fnnn8NvuVxOnnVkMpl0Op1Go9FoNPhhaGj49ttv5+bmVldXjzgxMTwoqxdTKBTw+AGtra08Hg9+0+n0oKAg7LhlMhlmDj/6+vry8vLc3d1tbW11dXXNzc1TU1M1s4QaPqampujfUQNCiCCItrY28vBh8+bNMplsqNNlMhlB6r4kEsmqVas0s4RKPjCqQwiJRCKpVGpubo4VrK2t3d3dGQzGiJdiMpkRERGGhoaaWUINH2gK3D6QElDKxW3fvn00QQCTyQwNDdXYEmr4GBoastns1tZW+Nvf348QUopQ/fz8DAwMhr8Oi8UKDQ3VeJyLKIzfLCwscA4A+Ch1Kbq6uhs3blQK51QRERGhjSWUxW/knMbAwABSaR+EUFBQkFQqxX+VelsWi7V161YtE9lU8sE5J+DDZrOVdFxdXWfOnIlpEARBbi65XL53714tzaCMj0AgGE3Oadu2bdjLTZkyBTtxFou1fv36GTNmaGkGZXysra3xnBQ8dfKrhbFp0yZoHzabvXbtWtztyGSy6Oho7c2gjM+MGTNaW1uhmJrJZKIh+Jiamq5atYpOp0skkoCAgClTpiCEWCzW6tWrZ82apb0ZlNW/QUBZVVVlZmYGcyTBwcENDQ0ikUgsFnd0dHA4HA6HY2RkxOFwFAqFsbGxTCbz8PBIS0uTSqX79u2jxg5tgj8MhUKRnZ1Np9PJMcFoAFmumTNn9vT0UGKJtvMlnZ2dCQkJp06dqqmpUTrEZDJtbGzMzMwMDAy4XG5vb29vb29HR0d1dTU5cgUYGBj4+/vv3bsXJvA0h8ZPor29PSYmRql+0srKKigo6PTp02VlZbgmXhUPHjy4evVqVFSUq6sruRei0+m+vr5FRUUaW6UJH4VCAbkobAeXyw0JCfnzzz81uFp+fn5cXBzME+OGDQ8PF4lEGlxtzO9bQ0NDQEBAVlYW/OXz+RERETt37lSKiJ8/f15SUlJRUVFfX4/9gb6+Po/Hc3BwcHR0nDNnDrhB/Jpcvnw5Li7u8ePHILGwsDhz5ozagobhMCb2N2/exM3CZrOjo6N7e3vxUYlEcuXKlcDAQFhyNTwMDAy8vLyOHz+ulP5NTU0VCASgw2Aw9u/fD6OjUWIMfBISEnDm1tXVtbS0FB9qamqKjo4mv4GjB5vN9vPze/jwIb5aV1dXSEgIVvDx8enr66OYz6FDh/ANgoODcV2hSCTas2eP0uSCnZ3dtm3bTp06dffu3bq6uvb2doIgenp6mpubCwsLU1NTY2Ji3NzcyO8bQsjT0/Off/7Bd7x06RIeQbm5uY2yAGtUfA4cOADXZTKZycnJWH7+/HlyOMzn86Oiop48eTLKZ9TW1vbTTz+5uLjgK7BYrMjISNwaZWVlOJHi5uY2mj5qZD4JCQlwRT09vStXroBQLBYHBwdjOwQCwfHjx8Vi8SiZKOHWrVvkwqy5c+dWVFTAobq6OtwjeXl5DdMHjIpPZmYmhMNMJhOTaWlpwQ+VwWBERER0dXVpxgRDoVCcOXOGz+fDZY2MjO7cuQOHGhsbbW1tQb5r1y7N+cBaa4QQjUbD
hau1tbWOjo5wdQsLi5ycHC2ZkNHc3IwrY3R0dHBBw9OnT7GzSUtLG+YKQ/JRKBTY94eHh4OwpaUFk5k/f35TUxOFZAByuXzPnj34c8J1Fnfu3IE3xdjYeJja+iH5JCcnw0VdXFzAm4nFYvyaLVmyRPt3bBhgd8rhcPLz80G4f/9+EPr4+Ax1ono+7e3t0L5sNhv3M9gBuLi4jCsZwGeffQa3s7KygiUpMpls3rx5IExPT1d7lno+eKgYHR0NkvPnz2NXBjXU4w2FQrF27Vq46bp160CYl5cHfbpQKMTTzGSo4dPR0QFRM5/PB5ff1dUFMQidTsdu5zWgu7sbx6nXrl0DYUBAAEh+//131VPU8Dl48CCccOjQIZDgnFhERMT4Wa8W9+7dgwGFvb09dD5lZWXQRG+++aaqvjIfhUIBk2pcLhci9qamJghnLCwsXsNno4rAwEB4mklJSSDx9/cHCXYVGMp88ExtSEgISPC39OOPP4636WpRW1sLCSMHBwf4ZjIzM5U6EgxlPtiJweBMIpGAo+Pz+RqHM9pjy5YtYNXt27cJgpDL5fA983g8Ja+gzMfGxgZcJMw0Xb58GS706aefvjbrVfHo0SMwIzAwECS7d+8GSUFBAVlzEB+c09iyZQtINm3aBJKSkpLXYvmQgDJtyHIRBHHlyhUw7OjRo2S1QfnEu3fvwo8lS5bAD/ic7OzsNE67JCUlrVmzZt++fdu3b1etdBk9PvjgA4SQSCQqKChACHl4eMDwKTc3d5AemVxUVBQIy8vLCYKorq6Gv9u2bdPsoVK41U5GRgYYEx8fDxIo954xYwZZbRAfb29vhBCTyYSADbfpqVOnNLCA2q122tvbwRj8CUGlE4PBIE+JD3rfoCTSxsYGpjpwFQeOqccE8lIfwLJlyzRe6mNiYgKetqKigmyVXC4nr00ZxAey6Thni+c/wOmNFZRvtQNm4FkmbCd5d5VBfCANizNpOCur2SJyyrfaMTY2Rgh1d3crWYUlSIlPb28vIi0lwCU5mi0uoHyrHZhOxk9Z9bkjJT4wgwuzhYg0YYglYwLlW+0oTZvjiiHyRO0gPvAAcPOpfQCjB+Vb7YBh2CpVO5ESH3hD8BOFqg9EcgxjgupSn6ysLDabrfFSH3g02Cps55B8IHn37NkzgiDQv1NuCKHKykoNbq+61OfkyZP79u0je7zRQyKRQDiGOw9ccG9lZYXVBmVcHR0ds7Kyent7GxoaBAIBHhsWFRXhoe+YEBUVZWZmtnPnTljqExkZqfFWOyUlJbBMClsFHRGfzx+0zoDcBx87dgyE169fJwhCKpVCUy5cuFCDHp1axMfHg21Xr14l2+bh4UFWG/S+ubm5wQ8ITJlMJny7jx8/xrVG/yvAGI7JZEJmOD8/H/zBwoULyWqD+MyfPx/aDk9XQR2aTCY7d+7cazFbPRobG8GvLF68GJwWtnDp0qWDVJWadfXq1QghOp3+4sULgiBaW1uhF3JxcXldb5Ya4PQi3tJiwYIFCCEdHR2lSQdlPqdPn4YzcVzs5+cHEhjrvn709fVB2YWhoSGkaCorKyHU8PX1VVJW5tPT0wO9qpOTEwy5Hz58CHzc3d3HYweTEfH999+DAVFRUSCJiYkBycWLF5WU1eTfgoKCQPvSpUsgWblyJUjOnj07rqaroqmpCcJQDocD6X+RSAQxrrm5uWoxsBo+paWlOGEHDVJSUgIZIz6fPx5zCsMAp9ri4uJAEhsbC5KDBw+q6qvPX/v6+sI5586dA0lkZCRIPD091SaOxwN4atDJyQmGzM3NzdA4xsbGMIxXgno+RUVFkG0QCASQE+3r64O9QRBCe/bsGVcagKysLIj3dXR0cFIK50rVNg4xzPxPeHg4nIkTpRUVFXgIRc4KjAcKCgrgs0EIHT9+HISZmZk4lz1UGf2QfEQiER6o4Ex+dnY2HmyEhYWNk7vLzs7GZPCEKXny8+bNm0OdO9z86Z07d8AxGBsbl5WVgfDChQu4
6HPdunWUZ+gTEhJwYXBgYCA8MvJSorCwsGFOH2F++8svv4SrWFpa1tbWgjAjIwOPwO3t7e/du0cJk8bGRuzNEEIRERFARi6Xr1mzBoQuLi7DL9gYgY9MJoOkHELI2dm5sbER5Pn5+XgOnUajbd68GbPVAGKx+LvvvsPvmI6ODv5m5HL5xx9/DPLR7LI2cj1FX18fbmsbGxtc6dDZ2blu3Tr8ONls9tatW8kbW48G9fX1hw4dgnAGu2bszQYGBvAtjIyM/vrrrxEvOKp6l46ODjyU4PF45CnHjIwMe3t7RIJQKIyKisrIyFDbP4CV+fn5R44cWbFiBXmFhr6+fmxsLK4MamhowM9R7dZ4mvMhCKKnp8fLywuuzmAwDhw4gMu4JBJJcnKy2gVWPB7Pzc1txYoV/v7+Xl5e7u7udnZ2qqtMDA0NIyMjyZFHZmYmLqCwsLAYfcXiGOrFJBLJrl27sBHz5s3Ly8vDR+Vy+e3btzdt2oQ/gxHBZDKXLl2amJhILkVsbm4ODAzEWTsXF5cx7Uw45nrLtLQ0bDGdTg8ICMCuHCCTyR4+fBgfHx8YGLhgwQJLS0uIUDgcDp/PnzVr1ocffhgTE5Oenq7k60UiUVxcHM6n0mi0sLCwsW54p0n9aE1NDXZ6wMrf3//WrVsax3Xl5eUxMTHkzLC9vf0wneYw0Ly+Nz09HebMMAQCwe7du9PT00dTyiqVSh88eBAXFwcjTQwul3vw4EGNVwVqVX+tUCguXboUFxdXWFio9GE4ODg4OTk5ODiYmZkZGhqamJj09PT09PR0dnZWVVWVl5eXlZWR8+gIIXNz8/Dw8NDQ0NF/gWqg2WNQQn5+flhYmGb1ozo6Oj4+PhcvXqRkb0gq979WKBRFRUVZWVm5ubmlpaXPnz8nb5REBo/HEwqFbm5uS5cuXbRoEVWbI6Fx3c9bIpFUVVV1dXXBa8bhcAwMDGBR91A7p2iPyf3JJzb+a3z+D3Ww9w5uHkfIAAAAAElFTkSuQmCC\n",
+ "text/plain": [
+ "<IPython.core.display.Image object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "gtn.draw(g1, \"g1.png\")\n",
+ "gtn.draw(g2, \"g2.png\")\n",
+ "display(Image(\"g1.png\"), Image(\"g2.png\"))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ "<IPython.core.display.Image object>"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "intersect = gtn.intersect(g1, g2)\n",
+ "gtn.draw(intersect, \"intersect.png\")\n",
+ "Image(\"intersect.png\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[1.0, 0.0, 1.0, 0.0]\n"
+ ]
+ }
+ ],
+ "source": [
+ "score = gtn.viterbi_score(intersect)\n",
+ "gtn.backward(score)\n",
+ "\n",
+ "# print gradients of arc weights \n",
+ "print(g1.grad().weights_to_list()) "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.2"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/src/notebooks/g1.png b/src/notebooks/g1.png
new file mode 100644
index 0000000..09dd49e
--- /dev/null
+++ b/src/notebooks/g1.png
Binary files differ
diff --git a/src/notebooks/g2.png b/src/notebooks/g2.png
new file mode 100644
index 0000000..a3cf21e
--- /dev/null
+++ b/src/notebooks/g2.png
Binary files differ
diff --git a/src/notebooks/intersect.png b/src/notebooks/intersect.png
new file mode 100644
index 0000000..63b7f2f
--- /dev/null
+++ b/src/notebooks/intersect.png
Binary files differ
diff --git a/src/tasks/build_transitions.py b/src/tasks/build_transitions.py
new file mode 100644
index 0000000..b12c9bc
--- /dev/null
+++ b/src/tasks/build_transitions.py
@@ -0,0 +1,263 @@
+"""Builds transition graph.
+
+Most code stolen from here:
+
+ https://github.com/facebookresearch/gtn_applications/blob/master/scripts/build_transitions.py
+
+"""
+
+import collections
+import itertools
+from pathlib import Path
+from typing import Dict, List, Optional
+
+import click
+import gtn
+from loguru import logger
+
+
+# Sentinel token indices that count_ngrams puts at the start and end of every
+# line. Real tokens get their index from enumerate() in cli, so they are never
+# negative and cannot collide with these.
+START_IDX = -1
+END_IDX = -2
+# Separator used by parse_lines to split a line into words.
+WORDSEP = "_"
+
+
+def build_graph(ngrams: List, disable_backoff: bool = False) -> gtn.Graph:
+    """Returns a gtn Graph based on the ngrams.
+
+    Every state is a tuple of token indices (an n-gram history) and is mapped
+    to one graph node. Each gram adds an arc from the node of its history
+    `gram[:-1]` to the node of its output state, labelled with its last token.
+
+    Args:
+        ngrams: One collection of grams per order, each gram being a tuple of
+            token indices. The number of collections is the model order.
+        disable_backoff: If True, no epsilon back-off arcs are added.
+
+    Returns:
+        The transition graph.
+
+    Raises:
+        ValueError: If `gram[1:]` of some gram does not contain the end token
+            and is not a state that has already been added.
+    """
+    # NOTE(review): the positional False is presumably gtn's `calc_grad`
+    # flag - confirm against the gtn API.
+    graph = gtn.Graph(False)
+    ngram = len(ngrams)
+    state_to_node = {}
+
+    def get_node(state: tuple) -> int:
+        """Returns the node for `state`, creating it on first use."""
+        node = state_to_node.get(state, None)
+
+        if node is not None:
+            return node
+
+        # With a pure unigram model every node is both start and accepting.
+        start = state == tuple([START_IDX]) if ngram > 1 else True
+        end = state == tuple([END_IDX]) if ngram > 1 else True
+        node = graph.add_node(start, end)
+        state_to_node[state] = node
+
+        if not disable_backoff and not end:
+            # Add back off when adding node: link to the longest suffix of the
+            # state that already has a node.
+            for n in range(1, len(state) + 1):
+                backoff_node = state_to_node.get(state[n:], None)
+
+                # Epsilon transition to the back-off state.
+                if backoff_node is not None:
+                    graph.add_arc(node, backoff_node, gtn.epsilon)
+                    break
+        return node
+
+    for grams in ngrams:
+        for gram in grams:
+            istate, ostate = gram[:-1], gram[len(gram) - ngram + 1 :]
+            inode = get_node(istate)
+
+            if END_IDX not in gram[1:] and gram[1:] not in state_to_node:
+                raise ValueError(
+                    "Ill formed counts: if (x, y_1, ..., y_{n-1}) is above "
+                    "the n-gram threshold, then (y_1, ..., y_{n-1}) must be "
+                    "above the (n-1)-gram threshold"
+                )
+
+            if END_IDX in ostate:
+                # Merge all state having </s> into one as final graph generated
+                # will be similar.
+                ostate = tuple([END_IDX])
+
+            onode = get_node(ostate)
+            # p(gram[-1] | gram[:-1]); the end token is emitted as epsilon.
+            graph.add_arc(
+                inode, onode, gtn.epsilon if gram[-1] == END_IDX else gram[-1]
+            )
+    return graph
+
+
+def count_ngrams(lines: List, ngram: int, tokens_to_index: Dict) -> List:
+    """Counts the number of ngrams.
+
+    Args:
+        lines: Lines, each an iterable of tokens.
+        ngram: Highest order to count; orders 1..ngram are counted.
+        tokens_to_index: Maps each token to its integer index.
+
+    Returns:
+        A list of `ngram` Counters; the Counter at position n maps tuples of
+        n + 1 token indices to their number of occurrences.
+
+    Raises:
+        KeyError: If a line contains a token missing from `tokens_to_index`.
+    """
+    counts = [collections.Counter() for _ in range(ngram)]
+    for line in lines:
+        # Wrap the line in the implicit start and end tokens.
+        token_line = [START_IDX, *(tokens_to_index[t] for t in line), END_IDX]
+        for n, counter in enumerate(counts):
+            # Unigrams skip the start token; when only unigrams are counted
+            # the end token is skipped as well.
+            start_offset = n == 0
+            end_offset = ngram == 1
+            for e in range(n + start_offset, len(token_line) - end_offset):
+                counter[tuple(token_line[e - n : e + 1])] += 1
+
+    return counts
+
+
+def prune_ngrams(ngrams: List, prune: List) -> List:
+    """Drops the grams whose count does not exceed the threshold of their order.
+
+    Args:
+        ngrams: One Counter of grams per order.
+        prune: One threshold per order; a gram is kept only if its count is
+            strictly greater than the threshold.
+
+    Returns:
+        One list of surviving grams per order, most frequent first.
+    """
+    return [
+        [gram for gram, count in counter.most_common() if count > prune[order]]
+        for order, counter in enumerate(ngrams)
+    ]
+
+
+def add_blank_grams(pruned_ngrams: List, num_tokens: int, blank: str) -> List:
+    """Adds blank token to grams.
+
+    The blank token is represented by the index `num_tokens`.
+
+    Args:
+        pruned_ngrams: One list of grams (tuples of token indices) per order.
+        num_tokens: Index used for the blank token.
+        blank: "optional" to try every combination of blank / no blank around
+            the tokens of a gram, "forced" to put a blank at every position.
+
+    Returns:
+        The grams with the blank variants appended. In "optional" mode the
+        given lists are extended in place. In "forced" mode a new outer list
+        is returned that keeps the given unigram list (extended in place) and
+        rebuilds the higher orders from blank-containing grams only.
+
+    Raises:
+        ValueError: If there is at least one gram and `blank` is neither
+            "optional" nor "forced".
+    """
+    all_grams = [gram for grams in pruned_ngrams for gram in grams]
+    maxorder = len(pruned_ngrams)
+    blank_gram = tuple([num_tokens])
+    # Blank-containing grams added so far; only membership is ever checked.
+    blank_grams = set()
+    if blank == "forced":
+        pruned_ngrams = [pruned_ngrams[0] if i == 0 else [] for i in range(maxorder)]
+        pruned_ngrams[0].append(blank_gram)
+        blank_grams.add(blank_gram)
+
+    for gram in all_grams:
+        # Iterate over all possibilities by using a vector of 0s, 1s to
+        # denote whether a blank is being used at each position.
+        if blank == "optional":
+            # Given a gram ab.. of order n, we have n + 1 positions
+            # available whether to use blank or not.
+            onehot_vectors = itertools.product([0, 1], repeat=len(gram) + 1)
+        elif blank == "forced":
+            # Must include a blank token in between.
+            onehot_vectors = [[1] * (len(gram) + 1)]
+        else:
+            raise ValueError(
+                'Invalid value specified for blank. Must be "optional" or "forced".'
+            )
+
+        for j in onehot_vectors:
+            new_array = []
+            for idx, oz in enumerate(j[:-1]):
+                # No blank is placed in front of the start token.
+                if oz == 1 and gram[idx] != START_IDX:
+                    new_array.append(num_tokens)
+                new_array.append(gram[idx])
+            # No blank is placed after the end token.
+            if j[-1] == 1 and gram[-1] != END_IDX:
+                new_array.append(num_tokens)
+            # Collect every window of up to `maxorder` tokens that contains a
+            # blank and has not been added yet.
+            for n in range(maxorder):
+                for e in range(n, len(new_array)):
+                    cur_gram = tuple(new_array[e - n : e + 1])
+                    if num_tokens in cur_gram and cur_gram not in blank_grams:
+                        pruned_ngrams[n].append(cur_gram)
+                        blank_grams.add(cur_gram)
+
+    return pruned_ngrams
+
+
+def add_self_loops(pruned_ngrams: List) -> List:
+    """Extends each order with grams that repeat one token of a lower-order gram.
+
+    For every gram of order o, each token other than the start/end markers is
+    duplicated in place, and the resulting gram is appended to order o + 1
+    unless it is already known. The given lists are modified in place.
+
+    Args:
+        pruned_ngrams: One list of grams (tuples of token indices) per order.
+
+    Returns:
+        The same `pruned_ngrams` object, with the self-loop grams appended.
+    """
+    num_orders = len(pruned_ngrams)
+
+    # Set of every known gram, used for duplicate checks.
+    seen = {gram for grams in pruned_ngrams for gram in grams}
+    for order in range(1, num_orders):
+        for gram in pruned_ngrams[order - 1]:
+            for i, token in enumerate(gram):
+                # The start and end markers are never repeated.
+                if token in (START_IDX, END_IDX):
+                    continue
+                looped = gram[:i] + (token,) + gram[i:]
+                if looped in seen:
+                    continue
+                pruned_ngrams[order].append(looped)
+                seen.add(looped)
+    return pruned_ngrams
+
+
+def parse_lines(lines: List, lexicon: Path) -> List:
+    """Parses lines with a lexicon.
+
+    Each lexicon line holds a word followed by its tokens, separated by
+    whitespace. Every input line is split into words on WORDSEP and each word
+    is replaced by its tokens.
+
+    Args:
+        lines: Text lines whose words are separated by WORDSEP.
+        lexicon: Path to the lexicon file.
+
+    Returns:
+        One flat list of tokens per input line.
+
+    Raises:
+        KeyError: If a word of a line is missing from the lexicon.
+    """
+    with open(lexicon, "r") as f:
+        entries = (line.split() for line in f)
+        # Blank lines carry no entry and would fail on `entry[0]`.
+        lex = {entry[0]: entry[1:] for entry in entries if entry}
+    logger.debug(f"Loaded lexicon with {len(lex)} entries.")
+    return [[t for w in line.split(WORDSEP) for t in lex[w]] for line in lines]
+
+
+@click.command()
+@click.option("--data_dir", type=str, default=None, help="Path to dataset root.")
+@click.option(
+    "--tokens",
+    type=str,
+    required=True,
+    help="Path to token list (in order used with training), relative to data_dir.",
+)
+@click.option(
+    "--lexicon", type=str, default=None, help="Path to lexicon, relative to data_dir."
+)
+@click.option(
+    "--prune",
+    nargs=2,
+    type=int,
+    required=True,
+    help="Threshold values for pruning unigrams and bigrams.",
+)
+@click.option(
+    "--blank",
+    type=click.Choice(["none", "optional", "forced"]),
+    default="none",
+    help="Specifies the usage of blank token: "
+    "'none' - do not use blank token, "
+    "'optional' - allow an optional blank inbetween tokens, "
+    "'forced' - force a blank inbetween tokens (also referred to as garbage token).",
+)
+@click.option("--self_loops", is_flag=True, help="Add self loops for tokens")
+@click.option("--disable_backoff", is_flag=True, help="Disable backoff transitions")
+@click.option(
+    "--save_path",
+    type=str,
+    required=True,
+    help="Path to save transition graph, relative to data_dir.",
+)
+def cli(
+    data_dir: str,
+    tokens: str,
+    lexicon: str,
+    prune: List[int],
+    blank: str,
+    self_loops: bool,
+    disable_backoff: bool,
+    save_path: str,
+) -> None:
+    """CLI for creating the transitions."""
+    # The number of thresholds given with --prune is the order of the model.
+    logger.info(f"Building {len(prune)}-gram transition models.")
+
+    if data_dir is None:
+        # Default to <repo root>/data/processed/iam_lines.
+        data_dir = (
+            Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines"
+        )
+        logger.debug(f"Using data dir: {data_dir}")
+        if not data_dir.exists():
+            raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
+    else:
+        data_dir = Path(data_dir)
+
+    # Build table of counts and the back-off if below threshold.
+    with open(data_dir / "train.txt", "r") as f:
+        lines = [line.strip() for line in f]
+
+    with open(data_dir / tokens, "r") as f:
+        tokens = [line.strip() for line in f]
+
+    if lexicon is not None:
+        lexicon = data_dir / lexicon
+        lines = parse_lines(lines, lexicon)
+
+    tokens_to_idx = {t: e for e, t in enumerate(tokens)}
+
+    ngram = len(prune)
+
+    logger.info("Counting data...")
+    ngrams = count_ngrams(lines, ngram, tokens_to_idx)
+
+    pruned_ngrams = prune_ngrams(ngrams, prune)
+
+    for n in range(ngram):
+        logger.info(f"Kept {len(pruned_ngrams[n])} of {len(ngrams[n])} {n + 1}-grams")
+
+    # Blank grams are only added when a blank mode was requested;
+    # add_blank_grams rejects "none".
+    if blank != "none":
+        pruned_ngrams = add_blank_grams(pruned_ngrams, len(tokens_to_idx), blank)
+
+    if self_loops:
+        pruned_ngrams = add_self_loops(pruned_ngrams)
+
+    logger.info("Building graph from pruned ngrams...")
+    graph = build_graph(pruned_ngrams, disable_backoff)
+    logger.info(f"Graph has {graph.num_arcs()} arcs and {graph.num_nodes()} nodes.")
+
+    save_path = str(data_dir / save_path)
+
+    logger.info(f"Saving graph to {save_path}")
+    gtn.save(save_path, graph)
+
+# Run the click command when the module is executed as a script.
+if __name__ == "__main__":
+ cli()
diff --git a/src/tasks/make_wordpieces.py b/src/tasks/make_wordpieces.py
new file mode 100644
index 0000000..f605920
--- /dev/null
+++ b/src/tasks/make_wordpieces.py
@@ -0,0 +1,114 @@
+"""Creates word pieces from a text file.
+
+Most code stolen from:
+
+ https://github.com/facebookresearch/gtn_applications/blob/master/scripts/make_wordpieces.py
+
+"""
+import io
+from pathlib import Path
+from typing import List, Optional, Union
+
+import click
+from loguru import logger
+import sentencepiece as spm
+
+from text_recognizer.datasets.iam_preprocessor import load_metadata
+
+
+def iamdb_pieces(
+    data_dir: Path, text_file: str, num_pieces: int, output_prefix: str
+) -> None:
+    """Trains word pieces on the iamdb train text and stores them on disk.
+
+    Args:
+        data_dir (Path): Directory containing the text file; the output files
+            are written here as well.
+        text_file (str): Name of the training text file inside ``data_dir``.
+        num_pieces (int): Number of word pieces to generate.
+        output_prefix (str): Prefix of the token and lexicon file names.
+
+    Raises:
+        RuntimeError: If the word "move" is missing from the vocabulary.
+
+    """
+    # Load training text.
+    with open(data_dir / text_file, "r") as f:
+        sentences = [row.strip() for row in f]
+
+    # One additional piece is requested to account for <unk>, and "/" is
+    # registered as a user symbol so the token is in the output set.
+    processor = train_spm_model(iter(sentences), num_pieces + 1, user_symbols=["/"])
+
+    # The vocabulary is every non-empty word, where words are separated by "_".
+    words = set()
+    for sentence in sentences:
+        words.update(word for word in sentence.split("_") if word)
+    vocab = sorted(words)
+
+    if "move" not in vocab:
+        raise RuntimeError("`MOVE` not in vocab")
+
+    save_pieces(processor, num_pieces, data_dir, output_prefix, vocab)
+
+
+def train_spm_model(
+    sentences: iter, vocab_size: int, user_symbols: Union[str, List[str]] = ""
+) -> spm.SentencePieceProcessor:
+    """Trains a sentence piece model and returns a processor loaded with it.
+
+    Args:
+        sentences (iter): Iterator over the training sentences.
+        vocab_size (int): Vocabulary size handed to the trainer.
+        user_symbols (Union[str, List[str]]): User defined symbols handed to
+            the trainer. Defaults to "".
+
+    Returns:
+        spm.SentencePieceProcessor: Processor created from the trained model.
+
+    """
+    # The trained model is written to an in-memory buffer rather than a file.
+    buffer = io.BytesIO()
+    trainer_options = dict(
+        sentence_iterator=sentences,
+        model_writer=buffer,
+        vocab_size=vocab_size,
+        bos_id=-1,
+        eos_id=-1,
+        character_coverage=1.0,
+        user_defined_symbols=user_symbols,
+    )
+    spm.SentencePieceTrainer.train(**trainer_options)
+    return spm.SentencePieceProcessor(model_proto=buffer.getvalue())
+
+
+def save_pieces(
+    sp: spm.SentencePieceProcessor,
+    num_pieces: int,
+    data_dir: Path,
+    output_prefix: str,
+    vocab: set,
+) -> None:
+    """Writes the word pieces and the word-to-pieces lexicon to disk.
+
+    Two files are created in ``data_dir``:
+    ``{output_prefix}_tokens_{num_pieces}.txt`` with one piece per line, and
+    ``{output_prefix}_lex_{num_pieces}.txt`` with one line per vocabulary word,
+    holding the word followed by its space separated pieces.
+
+    Args:
+        sp (spm.SentencePieceProcessor): The trained sentence piece processor.
+        num_pieces (int): Number of word pieces to export.
+        data_dir (Path): Directory the files are written to.
+        output_prefix (str): Prefix of the two file names.
+        vocab (set): The words to encode.
+
+    """
+    logger.info(f"Generating word piece list of size {num_pieces}.")
+    # Ids 1..num_pieces are exported; id 0 is skipped (presumably <unk>).
+    pieces = list(map(sp.id_to_piece, range(1, num_pieces + 1)))
+
+    logger.info(f"Encoding vocabulary of size {len(vocab)}.")
+    lexicon_rows = [
+        f"{word} {' '.join(sp.encode_as_pieces(word))}\n" for word in vocab
+    ]
+
+    # Save pieces to file.
+    tokens_path = data_dir / f"{output_prefix}_tokens_{num_pieces}.txt"
+    with open(tokens_path, "w") as f:
+        f.write("\n".join(pieces))
+
+    # Save lexicon to a file.
+    lexicon_path = data_dir / f"{output_prefix}_lex_{num_pieces}.txt"
+    with open(lexicon_path, "w") as f:
+        f.writelines(lexicon_rows)
+
+
+@click.command()
+@click.option("--data_dir", type=str, default=None, help="Path to processed iam dir.")
+@click.option(
+    "--text_file", type=str, default=None, help="Name of sentence piece training text."
+)
+@click.option(
+    "--output_prefix",
+    type=str,
+    default="word_pieces",
+    help="Prefix name to store tokens and lexicon.",
+)
+@click.option("--num_pieces", type=int, default=1000, help="Number of word pieces.")
+def cli(
+    data_dir: Optional[str],
+    text_file: Optional[str],
+    output_prefix: Optional[str],
+    num_pieces: Optional[int],
+) -> None:
+    """CLI for training the sentence piece model.
+
+    Args:
+        data_dir (Optional[str]): Directory with the training text. If None,
+            ``data/processed/iam_lines`` relative to the repository root is used.
+        text_file (Optional[str]): Name of the training text inside ``data_dir``.
+        output_prefix (Optional[str]): Prefix of the generated file names.
+        num_pieces (Optional[int]): Number of word pieces to generate.
+
+    Raises:
+        RuntimeError: If no text file is given or the data directory is missing.
+
+    """
+    # The option defaults to None, which would otherwise surface as an opaque
+    # TypeError when the path is joined.
+    if text_file is None:
+        raise RuntimeError("No training text given, please set --text_file.")
+
+    if data_dir is None:
+        data_dir = (
+            Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines"
+        )
+        logger.debug(f"Using data dir: {data_dir}")
+    else:
+        data_dir = Path(data_dir)
+
+    # Validate the directory no matter if it was given or derived.
+    if not data_dir.exists():
+        raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
+
+    iamdb_pieces(data_dir, text_file, num_pieces, output_prefix)
+
+
+if __name__ == "__main__":
+ cli()
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py
index d8372e3..a6c1c59 100644
--- a/src/text_recognizer/datasets/__init__.py
+++ b/src/text_recognizer/datasets/__init__.py
@@ -8,6 +8,7 @@ from .emnist_lines_dataset import (
from .iam_dataset import IamDataset
from .iam_lines_dataset import IamLinesDataset
from .iam_paragraphs_dataset import IamParagraphsDataset
+from .iam_preprocessor import load_metadata, Preprocessor
from .transforms import AddTokens, Transpose
from .util import (
_download_raw_dataset,
@@ -29,8 +30,10 @@ __all__ = [
"EmnistMapper",
"EmnistLinesDataset",
"get_samples_by_character",
+ "load_metadata",
"IamDataset",
"IamLinesDataset",
"IamParagraphsDataset",
+ "Preprocessor",
"Transpose",
]
diff --git a/src/text_recognizer/datasets/iam_preprocessor.py b/src/text_recognizer/datasets/iam_preprocessor.py
new file mode 100644
index 0000000..5a5136c
--- /dev/null
+++ b/src/text_recognizer/datasets/iam_preprocessor.py
@@ -0,0 +1,196 @@
+"""Preprocessor for extracting word letters from the IAM dataset.
+
+The code is mostly stolen from:
+
+ https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py
+
+"""
+
+import collections
+import itertools
+from pathlib import Path
+import re
+from typing import List, Optional, Union
+
+import click
+from loguru import logger
+import torch
+
+
+def load_metadata(
+    data_dir: Path, wordsep: str, use_words: bool = False
+) -> collections.defaultdict:
+    """Loads IAM metadata and returns it as a dictionary.
+
+    Reads ``ascii/words.txt`` (if ``use_words``) or ``ascii/lines.txt`` below
+    ``data_dir`` and groups the entries by form.
+
+    Args:
+        data_dir (Path): Root directory of the IAM database.
+        wordsep (str): Word separator that replaces ``|`` and whitespace in the
+            transcriptions.
+        use_words (bool): If the word segmented metadata should be loaded
+            instead of the line metadata. Defaults to False.
+
+    Returns:
+        collections.defaultdict: Maps the form key (the first two dash
+            separated parts of an entry id) to a list of dicts with the keys
+            ``key`` (the first three parts of the id), ``box`` (a tuple of four
+            ints, presumably x, y, w, h) and ``text``.
+
+    """
+    forms = collections.defaultdict(list)
+    filename = "words.txt" if use_words else "lines.txt"
+    # The four box columns start one column earlier in the words file.
+    box_idx = 3 if use_words else 4
+
+    with open(data_dir / "ascii" / filename, "r") as f:
+        for raw_line in f:
+            # Skip the comment header.
+            if raw_line.startswith("#"):
+                continue
+            line = raw_line.strip().split()
+            # Skip blank lines, they would otherwise raise an IndexError below.
+            if not line:
+                continue
+            # Skip word segmentation errors.
+            if use_words and line[1] == "err":
+                continue
+            text = " ".join(line[8:])
+
+            # Remove garbage tokens:
+            text = text.replace("#", "")
+
+            # Swap word sep from | to wordsep
+            text = re.sub(r"\|+|\s", wordsep, text).strip(wordsep)
+            id_parts = line[0].split("-")
+            form_key = "-".join(id_parts[:2])
+            line_key = "-".join(id_parts[:3])
+            box = tuple(int(val) for val in line[box_idx : box_idx + 4])
+            forms[form_key].append({"key": line_key, "box": box, "text": text})
+    return forms
+
+
+class Preprocessor:
+    """A preprocessor for the IAM dataset.
+
+    Loads the IAM metadata, collects the set of (lower cased) graphemes and
+    converts between text and index sequences, optionally through a word piece
+    lexicon.
+    """
+
+    # TODO: add lower case only to when generating...
+
+    def __init__(
+        self,
+        data_dir: Union[str, Path],
+        num_features: int,
+        tokens_path: Optional[Union[str, Path]] = None,
+        lexicon_path: Optional[Union[str, Path]] = None,
+        use_words: bool = False,
+        prepend_wordsep: bool = False,
+    ) -> None:
+        """Loads the metadata and builds the token maps.
+
+        Args:
+            data_dir (Union[str, Path]): Root directory of the IAM database.
+            num_features (int): Stored on the instance as ``num_features``.
+            tokens_path (Optional[Union[str, Path]]): File with one token per
+                line. If None, the graphemes are used as tokens.
+            lexicon_path (Optional[Union[str, Path]]): File with one word per
+                line followed by its tokens. If None, no lexicon is used.
+            use_words (bool): If the word segmented metadata should be loaded.
+            prepend_wordsep (bool): If ``to_index`` should prepend the word
+                separator to every target.
+
+        """
+        self.wordsep = "_"
+        self._use_word = use_words
+        self._prepend_wordsep = prepend_wordsep
+
+        self.data_dir = Path(data_dir)
+
+        self.forms = load_metadata(self.data_dir, self.wordsep, use_words=use_words)
+
+        # Load the set of graphemes:
+        graphemes = set()
+        for form in self.forms.values():
+            for line in form:
+                graphemes.update(line["text"].lower())
+        self.graphemes = sorted(graphemes)
+
+        # Build the token-to-index and index-to-token maps.
+        if tokens_path is not None:
+            with open(tokens_path, "r") as f:
+                self.tokens = [line.strip() for line in f]
+        else:
+            self.tokens = self.graphemes
+
+        if lexicon_path is not None:
+            with open(lexicon_path, "r") as f:
+                lexicon = (line.strip().split() for line in f)
+                # The first column is the word, the rest are its tokens.
+                lexicon = {line[0]: line[1:] for line in lexicon}
+                self.lexicon = lexicon
+        else:
+            self.lexicon = None
+
+        self.graphemes_to_index = {t: i for i, t in enumerate(self.graphemes)}
+        self.tokens_to_index = {t: i for i, t in enumerate(self.tokens)}
+        self.num_features = num_features
+        self.text = []
+
+    @property
+    def num_tokens(self) -> int:
+        """Returns the number of tokens."""
+        return len(self.tokens)
+
+    @property
+    def use_words(self) -> bool:
+        """If words are used."""
+        return self._use_word
+
+    def extract_train_text(self) -> None:
+        """Appends the lower cased text of every training example to `text`.
+
+        The training keys are read from ``task/trainset.txt`` in the data dir.
+        """
+        # A set keeps the membership test below constant time per example.
+        with open(self.data_dir / "task" / "trainset.txt") as f:
+            keys = {line.strip() for line in f}
+
+        for examples in self.forms.values():
+            for example in examples:
+                if example["key"] not in keys:
+                    continue
+                self.text.append(example["text"].lower())
+
+    def to_index(self, line: str) -> torch.LongTensor:
+        """Converts text to a tensor of indices.
+
+        Without a lexicon every character is looked up in the grapheme map.
+        With a lexicon the line is split into words and every word is replaced
+        by its tokens, which are looked up in the token map.
+        """
+        token_to_index = self.graphemes_to_index
+        if self.lexicon is not None:
+            if len(line) > 0:
+                # If the word is not found in the lexicon, fall back to letters.
+                line = [
+                    t
+                    for w in line.split(self.wordsep)
+                    for t in self.lexicon.get(w, self.wordsep + w)
+                ]
+            token_to_index = self.tokens_to_index
+        if self._prepend_wordsep:
+            line = itertools.chain([self.wordsep], line)
+        return torch.LongTensor([token_to_index[t] for t in line])
+
+    def to_text(self, indices: List[int]) -> str:
+        """Converts indices to text."""
+        # Roughly the inverse of `to_index`
+        encoding = self.graphemes
+        if self.lexicon is not None:
+            encoding = self.tokens
+        return self._post_process(encoding[i] for i in indices)
+
+    def tokens_to_text(self, indices: List[int]) -> str:
+        """Converts token indices to text."""
+        return self._post_process(self.tokens[i] for i in indices)
+
+    def _post_process(self, indices: List[int]) -> str:
+        """Joins the tokens and strips the word separator from both ends.
+
+        Despite its name and annotation, `indices` receives an iterable of
+        already decoded token strings from `to_text` and `tokens_to_text`.
+        """
+        return "".join(indices).strip(self.wordsep)
+
+
+@click.command()
+@click.option("--data_dir", type=str, default=None, help="Path to iam dataset")
+@click.option(
+    "--use_words", is_flag=True, help="Load word segmented dataset instead of lines"
+)
+@click.option(
+    "--save_text", type=str, default=None, help="Path to save parsed train text"
+)
+@click.option("--save_tokens", type=str, default=None, help="Path to save tokens")
+def cli(
+    data_dir: Optional[str],
+    use_words: bool,
+    save_text: Optional[str],
+    save_tokens: Optional[str],
+) -> None:
+    """CLI for extracting text data from the iam dataset.
+
+    Args:
+        data_dir (Optional[str]): Path to the iamdb directory. If None,
+            ``data/raw/iam/iamdb`` relative to the repository root is used.
+        use_words (bool): If the word segmented metadata should be loaded.
+        save_text (Optional[str]): File name for the training text; nothing is
+            saved if None.
+        save_tokens (Optional[str]): File name for the tokens; nothing is saved
+            if None.
+
+    Raises:
+        RuntimeError: If the data directory is missing.
+
+    """
+    if data_dir is None:
+        data_dir = (
+            Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb"
+        )
+        logger.debug(f"Using data dir: {data_dir}")
+    else:
+        # Resolved so the parents[2] lookup below also works for short
+        # relative paths.
+        data_dir = Path(data_dir).resolve()
+
+    # Validate the directory no matter if it was given or derived.
+    if not data_dir.exists():
+        raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
+
+    preprocessor = Preprocessor(data_dir, 64, use_words=use_words)
+    preprocessor.extract_train_text()
+
+    # Three levels above the iamdb directory is where the processed directory
+    # is expected, i.e. <data>/processed/iam_lines for <data>/raw/iam/iamdb.
+    processed_dir = data_dir.parents[2] / "processed" / "iam_lines"
+    logger.debug(f"Saving processed files at: {processed_dir}")
+
+    if save_text is not None:
+        logger.info("Saving training text")
+        with open(processed_dir / save_text, "w") as f:
+            f.write("\n".join(preprocessor.text))
+
+    if save_tokens is not None:
+        logger.info("Saving tokens")
+        with open(processed_dir / save_tokens, "w") as f:
+            f.write("\n".join(preprocessor.tokens))
+
+
+if __name__ == "__main__":
+ cli()
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py
index 8956b01..60987e0 100644
--- a/src/text_recognizer/datasets/transforms.py
+++ b/src/text_recognizer/datasets/transforms.py
@@ -1,14 +1,57 @@
"""Transforms for PyTorch datasets."""
+import random
+
import numpy as np
from PIL import Image
import torch
from torch import Tensor
import torch.nn.functional as F
-from torchvision.transforms import Compose, RandomAffine, RandomHorizontalFlip, ToTensor
+from torchvision import transforms
+from torchvision.transforms import (
+ ColorJitter,
+ Compose,
+ Normalize,
+ RandomAffine,
+ RandomHorizontalFlip,
+ RandomRotation,
+ ToPILImage,
+ ToTensor,
+)
from text_recognizer.datasets.util import EmnistMapper
+class RandomResizeCrop:
+    """Image transform with random resize and crop applied.
+
+    The image is padded with white, cropped at a randomly jittered position
+    and resized to a randomly scaled width while the height is kept.
+
+    Stolen from
+
+    https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py
+
+    """
+
+    def __init__(self, jitter: int = 10, ratio: float = 0.5) -> None:
+        # jitter: padding in pixels and the largest offset of the crop origin.
+        # ratio: the width is scaled by a factor from [1 - ratio, 1 + ratio].
+        self.jitter = jitter
+        self.ratio = ratio
+
+    def __call__(self, img: Image.Image) -> Image.Image:
+        """Applies a random crop and a random change of aspect ratio to an image.
+
+        NOTE(review): ``img.size`` is unpacked as (width, height), so the input
+        is presumably a PIL image and not a numpy array - confirm with callers.
+        """
+        w, h = img.size
+
+        # pad with white:
+        img = transforms.functional.pad(img, self.jitter, fill=255)
+
+        # crop at random (x, y):
+        # After padding, (jitter, jitter) is the origin of the unpadded image,
+        # so the crop origin lies within [0, 2 * jitter] of the padded image.
+        x = self.jitter + random.randint(-self.jitter, self.jitter)
+        y = self.jitter + random.randint(-self.jitter, self.jitter)
+
+        # randomize aspect ratio:
+        size_w = w * random.uniform(1 - self.ratio, 1 + self.ratio)
+        size = (h, int(size_w))
+        # Crop a (h, w) window at (y, x) and resize it to (h, int(size_w)).
+        img = transforms.functional.resized_crop(img, y, x, h, w, size)
+        return img
+
+
class Transpose:
"""Transposes the EMNIST image to the correct orientation."""
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py
index eb5dbce..7647d7e 100644
--- a/src/text_recognizer/models/__init__.py
+++ b/src/text_recognizer/models/__init__.py
@@ -5,6 +5,7 @@ from .crnn_model import CRNNModel
from .ctc_transformer_model import CTCTransformerModel
from .segmentation_model import SegmentationModel
from .transformer_model import TransformerModel
+from .vqvae_model import VQVAEModel
__all__ = [
"CharacterModel",
@@ -13,4 +14,5 @@ __all__ = [
"Model",
"SegmentationModel",
"TransformerModel",
+ "VQVAEModel",
]
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index f2cd4b8..70f4cdb 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -332,7 +332,7 @@ class Model(ABC):
def summary(
self,
input_shape: Optional[Union[List, Tuple]] = None,
- depth: int = 4,
+ depth: int = 3,
device: Optional[str] = None,
) -> None:
"""Prints a summary of the network architecture."""
diff --git a/src/text_recognizer/models/transformer_model.py b/src/text_recognizer/models/transformer_model.py
index 12e497f..3f63053 100644
--- a/src/text_recognizer/models/transformer_model.py
+++ b/src/text_recognizer/models/transformer_model.py
@@ -6,9 +6,9 @@ import torch
from torch import nn
from torch import Tensor
from torch.utils.data import Dataset
-from torchvision.transforms import ToTensor
from text_recognizer.datasets import EmnistMapper
+import text_recognizer.datasets.transforms as transforms
from text_recognizer.models.base import Model
from text_recognizer.networks import greedy_decoder
@@ -60,13 +60,19 @@ class TransformerModel(Model):
eos_token=self.eos_token,
lower=self.lower,
)
- self.tensor_transform = ToTensor()
-
+ self.tensor_transform = transforms.Compose(
+ [transforms.ToTensor(), transforms.Normalize(mean=[0.912], std=[0.168])]
+ )
self.softmax = nn.Softmax(dim=2)
@torch.no_grad()
def _generate_sentence(self, image: Tensor) -> Tuple[List, float]:
src = self.network.extract_image_features(image)
+
+ # Added for vqvae transformer.
+ if isinstance(src, Tuple):
+ src = src[0]
+
memory = self.network.encoder(src)
confidence_of_predictions = []
diff --git a/src/text_recognizer/models/vqvae_model.py b/src/text_recognizer/models/vqvae_model.py
new file mode 100644
index 0000000..70f6f1f
--- /dev/null
+++ b/src/text_recognizer/models/vqvae_model.py
@@ -0,0 +1,80 @@
+"""Defines the VQVAEModel class."""
+from typing import Callable, Dict, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets import EmnistMapper
+from text_recognizer.models.base import Model
+
+
+class VQVAEModel(Model):
+    """Model for reconstructing images from codebook."""
+
+    def __init__(
+        self,
+        network_fn: Type[nn.Module],
+        dataset: Type[Dataset],
+        network_args: Optional[Dict] = None,
+        dataset_args: Optional[Dict] = None,
+        metrics: Optional[Dict] = None,
+        criterion: Optional[Callable] = None,
+        criterion_args: Optional[Dict] = None,
+        optimizer: Optional[Callable] = None,
+        optimizer_args: Optional[Dict] = None,
+        lr_scheduler: Optional[Callable] = None,
+        lr_scheduler_args: Optional[Dict] = None,
+        swa_args: Optional[Dict] = None,
+        device: Optional[str] = None,
+    ) -> None:
+        """Initializes the VQVAEModel.
+
+        All arguments are forwarded unchanged and in order to the base class.
+
+        NOTE(review): `dataset_args` is annotated as optional, but
+        ``dataset_args["args"]["pad_token"]`` is read unconditionally below, so
+        passing None fails.
+        """
+
+        super().__init__(
+            network_fn,
+            dataset,
+            network_args,
+            dataset_args,
+            metrics,
+            criterion,
+            criterion_args,
+            optimizer,
+            optimizer_args,
+            lr_scheduler,
+            lr_scheduler_args,
+            swa_args,
+            device,
+        )
+        self.pad_token = dataset_args["args"]["pad_token"]
+        # Fall back to an EMNIST mapper if the base class did not set one.
+        if self._mapper is None:
+            self._mapper = EmnistMapper(pad_token=self.pad_token,)
+        self.tensor_transform = ToTensor()
+        self.softmax = nn.Softmax(dim=0)
+
+    @torch.no_grad()
+    def predict_on_image(self, image: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+        """Reconstruction of image.
+
+        Args:
+            image (Union[np.ndarray, torch.Tensor]): The image to reconstruct.
+
+        Returns:
+            torch.Tensor: The first of the two values returned by the forward
+                pass, i.e. the reconstructed image.
+
+        """
+        self.eval()
+
+        if image.dtype == np.uint8:
+            # Converts an image with range [0, 255] to a PyTorch tensor with range [0, 1].
+            image = self.tensor_transform(image)
+        if image.dtype == torch.uint8:
+            # If the image is an unscaled tensor.
+            image = image.type("torch.FloatTensor") / 255
+
+        # Put the image tensor on the device the model weights are on.
+        image = image.to(self.device)
+        # The second value of the forward pass is not used here.
+        image_reconstructed, _ = self.forward(image)
+
+        return image_reconstructed
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
index 2b624bb..bac5d28 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/src/text_recognizer/networks/__init__.py
@@ -1,4 +1,5 @@
"""Network modules."""
+from .cnn import CNN
from .cnn_transformer import CNNTransformer
from .crnn import ConvolutionalRecurrentNetwork
from .ctc import greedy_decoder
@@ -7,15 +8,19 @@ from .lenet import LeNet
from .metrics import accuracy, cer, wer
from .mlp import MLP
from .residual_network import ResidualNetwork, ResidualNetworkEncoder
+from .transducer import TDS2d
from .transformer import Transformer
from .unet import UNet
from .util import sliding_window
from .vit import ViT
+from .vq_transformer import VQTransformer
+from .vqvae import VQVAE
from .wide_resnet import WideResidualNetwork
__all__ = [
"accuracy",
"cer",
+ "CNN",
"CNNTransformer",
"ConvolutionalRecurrentNetwork",
"DenseNet",
@@ -27,8 +32,11 @@ __all__ = [
"ResidualNetworkEncoder",
"sliding_window",
"UNet",
+ "TDS2d",
"Transformer",
"ViT",
+ "VQTransformer",
+ "VQVAE",
"wer",
"WideResidualNetwork",
]
diff --git a/src/text_recognizer/networks/cnn.py b/src/text_recognizer/networks/cnn.py
new file mode 100644
index 0000000..1807bb9
--- /dev/null
+++ b/src/text_recognizer/networks/cnn.py
@@ -0,0 +1,101 @@
+"""Implementation of a simple backbone cnn network."""
+from typing import Callable, Dict, Optional, Tuple
+
+from einops.layers.torch import Rearrange
+import torch
+from torch import nn
+
+from text_recognizer.networks.util import activation_function
+
+
+class CNN(nn.Module):
+    """A simple convolutional backbone network."""
+
+    def __init__(
+        self,
+        channels: Tuple[int, ...] = (1, 32, 64, 128),
+        kernel_sizes: Tuple[int, ...] = (4, 4, 4),
+        strides: Tuple[int, ...] = (2, 2, 2),
+        max_pool_kernel: int = 2,
+        dropout_rate: float = 0.2,
+        activation: Optional[str] = "relu",
+    ) -> None:
+        """Initialization of the CNN backbone.
+
+        Args:
+            channels (Tuple[int, ...]): Channels in the convolutional layers, the first entry is the
+                number of input channels. Defaults to (1, 32, 64, 128).
+            kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (4, 4, 4).
+            strides (Tuple[int, ...]): Stride length of the convolutional filter. Defaults to (2, 2, 2).
+            max_pool_kernel (int): 2D max pooling kernel. Defaults to 2.
+            dropout_rate (float): The dropout rate. Defaults to 0.2.
+            activation (Optional[str]): The name of non-linear activation function. Defaults to relu.
+
+        Raises:
+            RuntimeError: if the number of hyperparameters does not match in length.
+
+        """
+        super().__init__()
+
+        # Every convolution needs its own kernel size and stride. Either
+        # mismatch is an error, otherwise zip() in `_build_network` silently
+        # drops layers.
+        if len(channels) - 1 != len(kernel_sizes) or len(kernel_sizes) != len(strides):
+            raise RuntimeError("The number of the hyperparameters does not match.")
+
+        self.cnn = self._build_network(
+            channels, kernel_sizes, strides, max_pool_kernel, dropout_rate, activation,
+        )
+
+    def _build_network(
+        self,
+        channels: Tuple[int, ...],
+        kernel_sizes: Tuple[int, ...],
+        strides: Tuple[int, ...],
+        max_pool_kernel: int,
+        dropout_rate: float,
+        activation: str,
+    ) -> nn.Sequential:
+        """Stacks the convolutional layers into a sequential network.
+
+        The first layer is a plain convolution, all following layers are
+        activation -> batch norm -> convolution. A max pool is inserted before
+        the middle layer and a dropout follows every layer if the rate is non-zero.
+        """
+        # Load activation function.
+        activation_fn = activation_function(activation)
+
+        # The first entry is the number of input channels of the first layer.
+        channels = list(channels)
+        in_channels = channels.pop(0)
+        configuration = zip(channels, kernel_sizes, strides)
+
+        modules = nn.ModuleList([])
+
+        for i, (out_channels, kernel_size, stride) in enumerate(configuration):
+            # Add max pool to reduce output size.
+            if i == len(channels) // 2:
+                modules.append(nn.MaxPool2d(max_pool_kernel))
+            if i == 0:
+                modules.append(
+                    nn.Conv2d(
+                        in_channels, out_channels, kernel_size, stride=stride, padding=1
+                    )
+                )
+            else:
+                modules.append(
+                    nn.Sequential(
+                        activation_fn,
+                        nn.BatchNorm2d(in_channels),
+                        nn.Conv2d(
+                            in_channels,
+                            out_channels,
+                            kernel_size,
+                            stride=stride,
+                            padding=1,
+                        ),
+                    )
+                )
+
+            if dropout_rate:
+                modules.append(nn.Dropout2d(p=dropout_rate))
+
+            in_channels = out_channels
+
+        return nn.Sequential(*modules)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """The feedforward pass."""
+        # If the batch dimension is missing, leading dimensions are added until
+        # the input has four dimensions.
+        if len(x.shape) < 4:
+            x = x[(None,) * (4 - len(x.shape))]
+        return self.cnn(x)
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py
index 43e5403..7133c26 100644
--- a/src/text_recognizer/networks/cnn_transformer.py
+++ b/src/text_recognizer/networks/cnn_transformer.py
@@ -29,14 +29,22 @@ class CNNTransformer(nn.Module):
backbone: str,
backbone_args: Optional[Dict] = None,
activation: str = "gelu",
+ pool_kernel: Optional[Tuple[int, int]] = None,
) -> None:
super().__init__()
self.trg_pad_index = trg_pad_index
self.vocab_size = vocab_size
self.backbone = configure_backbone(backbone, backbone_args)
+
+ if pool_kernel is not None:
+ self.max_pool = nn.MaxPool2d(pool_kernel, stride=2)
+ else:
+ self.max_pool = None
+
self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim)
self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
+ self.pos_dropout = nn.Dropout(p=dropout_rate)
self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
nn.init.normal_(self.character_embedding.weight, std=0.02)
@@ -98,18 +106,23 @@ class CNNTransformer(nn.Module):
# If batch dimension is missing, it needs to be added.
if len(src.shape) < 4:
src = src[(None,) * (4 - len(src.shape))]
+
src = self.backbone(src)
+ if self.max_pool is not None:
+ src = self.max_pool(src)
+
if self.adaptive_pool is not None:
src = rearrange(src, "b c h w -> b w c h")
src = self.adaptive_pool(src)
src = src.squeeze(3)
else:
- src = rearrange(src, "b c h w -> b (w h) c")
+ src = rearrange(src, "b c h w -> b (h w) c")
b, t, _ = src.shape
src += self.src_position_embedding[:, :t]
+ src = self.pos_dropout(src)
return src
diff --git a/src/text_recognizer/networks/metrics.py b/src/text_recognizer/networks/metrics.py
index ffad792..2605731 100644
--- a/src/text_recognizer/networks/metrics.py
+++ b/src/text_recognizer/networks/metrics.py
@@ -1,4 +1,7 @@
"""Utility functions for models."""
+from typing import Optional
+
+from einops import rearrange
import Levenshtein as Lev
import torch
from torch import Tensor
@@ -32,22 +35,33 @@ def accuracy(outputs: Tensor, labels: Tensor, pad_index: int = 53) -> float:
return acc
-def cer(outputs: Tensor, targets: Tensor) -> float:
+def cer(
+ outputs: Tensor,
+ targets: Tensor,
+ batch_size: Optional[int] = None,
+ blank_label: Optional[int] = int,
+) -> float:
"""Computes the character error rate.
Args:
outputs (Tensor): The output from the network.
targets (Tensor): Ground truth labels.
+ batch_size (Optional[int]): Batch size if target and output has been flattend.
+ blank_label (Optional[int]): The blank character to be ignored. Defaults to 79.
Returns:
float: The cer for the batch.
"""
+ if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None:
+ targets = rearrange(targets, "(b t) -> b t", b=batch_size)
+ outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size)
+
target_lengths = torch.full(
size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
)
decoded_predictions, decoded_targets = greedy_decoder(
- outputs, targets, target_lengths
+ outputs, targets, target_lengths, blank_label=blank_label,
)
lev_dist = 0
@@ -63,22 +77,33 @@ def cer(outputs: Tensor, targets: Tensor) -> float:
return lev_dist / len(decoded_predictions)
-def wer(outputs: Tensor, targets: Tensor) -> float:
+def wer(
+ outputs: Tensor,
+ targets: Tensor,
+ batch_size: Optional[int] = None,
+ blank_label: Optional[int] = int,
+) -> float:
"""Computes the Word error rate.
Args:
outputs (Tensor): The output from the network.
targets (Tensor): Ground truth labels.
+ batch_size (optional[int]): Batch size if target and output has been flattend.
+ blank_label (Optional[int]): The blank character to be ignored. Defaults to 79.
Returns:
float: The wer for the batch.
"""
+ if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None:
+ targets = rearrange(targets, "(b t) -> b t", b=batch_size)
+ outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size)
+
target_lengths = torch.full(
size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
)
decoded_predictions, decoded_targets = greedy_decoder(
- outputs, targets, target_lengths
+ outputs, targets, target_lengths, blank_label=blank_label,
)
lev_dist = 0
diff --git a/src/text_recognizer/networks/transducer/__init__.py b/src/text_recognizer/networks/transducer/__init__.py
new file mode 100644
index 0000000..fdd6662
--- /dev/null
+++ b/src/text_recognizer/networks/transducer/__init__.py
@@ -0,0 +1,2 @@
+"""Transducer modules."""
+from .tds_conv import TDS2d
diff --git a/src/text_recognizer/networks/transducer/tds_conv.py b/src/text_recognizer/networks/transducer/tds_conv.py
new file mode 100644
index 0000000..018caf2
--- /dev/null
+++ b/src/text_recognizer/networks/transducer/tds_conv.py
@@ -0,0 +1,205 @@
+"""Time-Depth Separable Convolutions.
+
+References:
+ https://arxiv.org/abs/1904.02619
+ https://arxiv.org/pdf/2010.01003.pdf
+
+Code stolen from:
+ https://github.com/facebookresearch/gtn_applications
+
+
+"""
+from typing import List, Tuple
+
+from einops import rearrange
+import gtn
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class TDSBlock2d(nn.Module):
+    """Internal block of a 2D TDSC network.
+
+    A convolution over the last two dimensions with a residual connection,
+    followed by a fully connected block with a residual connection. Each of
+    the two sub-blocks is followed by an instance norm.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        img_depth: int,
+        kernel_size: Tuple[int, int],
+        dropout_rate: float,
+    ) -> None:
+        """Initializes the block.
+
+        Args:
+            in_channels (int): Number of channels of the convolution.
+            img_depth (int): The depth; ``in_channels * img_depth`` is the size
+                of the second dimension of the input.
+            kernel_size (Tuple[int, int]): Kernel size over the last two dimensions.
+            dropout_rate (float): The dropout rate.
+
+        """
+        super().__init__()
+
+        self.in_channels = in_channels
+        self.img_depth = img_depth
+        self.kernel_size = kernel_size
+        self.dropout_rate = dropout_rate
+        self.fc_dim = in_channels * img_depth
+
+        # Network placeholders.
+        self.conv = None
+        self.mlp = None
+        self.instance_norm = None
+
+        self._build_block()
+
+    def _build_block(self) -> None:
+        """Creates the convolutional block, the MLP block and the norms."""
+        # Convolutional block. The kernel is 1 along the depth dimension, and
+        # the padding of half the kernel keeps the size for odd kernels.
+        self.conv = nn.Sequential(
+            nn.Conv3d(
+                in_channels=self.in_channels,
+                out_channels=self.in_channels,
+                kernel_size=(1, self.kernel_size[0], self.kernel_size[1]),
+                padding=(0, self.kernel_size[0] // 2, self.kernel_size[1] // 2),
+            ),
+            nn.ReLU(inplace=True),
+            nn.Dropout(self.dropout_rate),
+        )
+
+        # MLP block.
+        self.mlp = nn.Sequential(
+            nn.Linear(self.fc_dim, self.fc_dim),
+            nn.ReLU(inplace=True),
+            nn.Dropout(self.dropout_rate),
+            nn.Linear(self.fc_dim, self.fc_dim),
+            nn.Dropout(self.dropout_rate),
+        )
+
+        # Instance norm.
+        self.instance_norm = nn.ModuleList(
+            [
+                nn.InstanceNorm2d(self.fc_dim, affine=True),
+                nn.InstanceNorm2d(self.fc_dim, affine=True),
+            ]
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward pass.
+
+        Args:
+            x (Tensor): Input tensor.
+
+        Shape:
+            - x: :math: `(B, CD, H, W)`
+
+        Returns:
+            Tensor: Output tensor of the same shape as the input.
+
+        """
+        C, D = self.in_channels, self.img_depth
+        residual = x
+        x = rearrange(x, "b (c d) h w -> b c d h w", c=C, d=D)
+        x = self.conv(x)
+        x = rearrange(x, "b c d h w -> b (c d) h w")
+        # Out-of-place add, so the output of the in-place ReLU is not modified.
+        x = x + residual
+
+        x = self.instance_norm[0](x)
+
+        # The MLP works on the last dimension, hence CD is moved there and back.
+        x = self.mlp(x.transpose(1, 3)).transpose(1, 3) + x
+        # The result of the second norm was previously computed and discarded.
+        x = self.instance_norm[1](x)
+
+        # Output shape: [B, CD, H, W]
+        return x
+
+
+class TDS2d(nn.Module):
+    """TDS Network.
+
+    Structure is the following:
+        Downsample layer -> TDS2d group -> ... -> Linear output layer
+
+
+    """
+
+    def __init__(
+        self,
+        input_dim: int,
+        output_dim: int,
+        depth: int,
+        tds_groups: List[dict],
+        kernel_size: Tuple[int, int],
+        dropout_rate: float,
+        in_channels: int = 1,
+    ) -> None:
+        """Initializes and builds the network.
+
+        Args:
+            input_dim (int): Input height; must be divisible by the product of
+                the vertical strides of all groups.
+            output_dim (int): Size of the output of the linear layer.
+            depth (int): Depth of the TDS blocks.
+            tds_groups (List[dict]): One dict per group with the keys
+                ``channels``, ``num_blocks`` and ``stride``.
+            kernel_size (Tuple[int, int]): Kernel size of all convolutions.
+            dropout_rate (float): The dropout rate.
+            in_channels (int): Number of input channels. Defaults to 1.
+
+        Raises:
+            RuntimeError: If the input height is not divisible by the total stride.
+
+        """
+        super().__init__()
+
+        self.in_channels = in_channels
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+        self.depth = depth
+        self.tds_groups = tds_groups
+        self.kernel_size = kernel_size
+        self.dropout_rate = dropout_rate
+
+        self.tds = None
+        self.fc = None
+
+        # Without this call `tds` and `fc` stay None and the forward pass fails.
+        self._build_network()
+
+    def _build_network(self) -> None:
+        """Builds the downsample layers, the TDS groups and the output layer."""
+        modules = []
+        stride_h = np.prod([grp["stride"][0] for grp in self.tds_groups])
+        if self.input_dim % stride_h:
+            raise RuntimeError(
+                f"Image height not divisible by total stride {stride_h}."
+            )
+
+        # Tracked locally, since `self.in_channels` is needed unchanged by the
+        # reshape in the forward pass.
+        in_channels = self.in_channels
+
+        for tds_group in self.tds_groups:
+            # Add downsample layer.
+            out_channels = self.depth * tds_group["channels"]
+            modules.extend(
+                [
+                    nn.Conv2d(
+                        in_channels=in_channels,
+                        out_channels=out_channels,
+                        kernel_size=self.kernel_size,
+                        padding=(self.kernel_size[0] // 2, self.kernel_size[1] // 2),
+                        stride=tds_group["stride"],
+                    ),
+                    nn.ReLU(inplace=True),
+                    nn.Dropout(self.dropout_rate),
+                    nn.InstanceNorm2d(out_channels, affine=True),
+                ]
+            )
+
+            for _ in range(tds_group["num_blocks"]):
+                modules.append(
+                    TDSBlock2d(
+                        tds_group["channels"],
+                        self.depth,
+                        self.kernel_size,
+                        self.dropout_rate,
+                    )
+                )
+
+            in_channels = out_channels
+
+        self.tds = nn.Sequential(*modules)
+        # Cast to a plain int, since the stride product is a numpy integer.
+        self.fc = nn.Linear(
+            int(in_channels * self.input_dim // stride_h), self.output_dim
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward pass.
+
+        Args:
+            x (Tensor): Input tensor.
+
+        Shape:
+            - x: :math: `(B, H, W)`
+
+        Returns:
+            Tensor: Output tensor, with one output vector per position along
+                the width of the feature map.
+
+        """
+        B, H, W = x.shape
+        # Split the height into the input channels and the remaining height.
+        x = rearrange(
+            x, "b (h1 h2) w -> b h1 h2 w", h1=self.in_channels, h2=H // self.in_channels
+        )
+        x = self.tds(x)
+
+        # x shape: [B, C, H, W]
+        x = rearrange(x, "b c h w -> b w (c h)")
+
+        return self.fc(x)
diff --git a/src/text_recognizer/networks/util.py b/src/text_recognizer/networks/util.py
index 711a952..131a6b4 100644
--- a/src/text_recognizer/networks/util.py
+++ b/src/text_recognizer/networks/util.py
@@ -65,13 +65,18 @@ def configure_backbone(backbone: str, backbone_args: Dict) -> Type[nn.Module]:
network_args = state_dict["network_args"]
weights = state_dict["model_state"]
+ freeze = False
+ if "freeze" in backbone_args and backbone_args["freeze"] is True:
+ backbone_args.pop("freeze")
+ freeze = True
+ network_args = backbone_args
+
# Initializes the network with trained weights.
backbone = backbone_(**network_args)
backbone.load_state_dict(weights)
- if "freeze" in backbone_args and backbone_args["freeze"] is True:
+ if freeze:
for params in backbone.parameters():
params.requires_grad = False
-
else:
backbone_ = getattr(network_module, backbone)
backbone = backbone_(**backbone_args)
diff --git a/src/text_recognizer/networks/vq_transformer.py b/src/text_recognizer/networks/vq_transformer.py
new file mode 100644
index 0000000..c673d96
--- /dev/null
+++ b/src/text_recognizer/networks/vq_transformer.py
@@ -0,0 +1,150 @@
+"""A VQ-Transformer for image to text recognition."""
+from typing import Dict, Optional, Tuple
+
+from einops import rearrange, repeat
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer import PositionalEncoding, Transformer
+from text_recognizer.networks.util import activation_function
+from text_recognizer.networks.util import configure_backbone
+from text_recognizer.networks.vqvae.encoder import _ResidualBlock
+
+
+class VQTransformer(nn.Module):
+    """VQ+Transformer for image to character sequence prediction.
+
+    A backbone exposing ``encode`` (returning features and a VQ loss) is followed
+    by a strided convolution. The result is laid out as a sequence, given a
+    learned position embedding and passed to a Transformer together with the
+    embedded target tokens. A linear head maps the output to vocabulary logits.
+    """
+
+    def __init__(
+        self,
+        num_encoder_layers: int,
+        num_decoder_layers: int,
+        hidden_dim: int,
+        vocab_size: int,
+        num_heads: int,
+        adaptive_pool_dim: Tuple,
+        expansion_dim: int,
+        dropout_rate: float,
+        trg_pad_index: int,
+        max_len: int,
+        backbone: str,
+        backbone_args: Optional[Dict] = None,
+        activation: str = "gelu",
+    ) -> None:
+        """Build the backbone, embeddings, Transformer and classification head.
+
+        Args:
+            num_encoder_layers (int): Forwarded to ``Transformer``.
+            num_decoder_layers (int): Forwarded to ``Transformer``.
+            hidden_dim (int): Channel count of the extra convolution and width of
+                the embeddings, the Transformer and the head input.
+            vocab_size (int): Size of the character embedding table and of the
+                head output.
+            num_heads (int): Forwarded to ``Transformer``.
+            adaptive_pool_dim (Tuple): Output size for ``nn.AdaptiveAvgPool2d``; a
+                falsy value disables the pooling.
+            expansion_dim (int): Forwarded to ``Transformer``.
+            dropout_rate (float): Forwarded to ``Transformer`` and to the target
+                ``PositionalEncoding``.
+            trg_pad_index (int): Target index treated as padding when masking.
+            max_len (int): Number of learned source position embeddings.
+            backbone (str): Backbone name handed to ``configure_backbone``.
+            backbone_args (Optional[Dict]): Arguments handed to
+                ``configure_backbone``. ``None`` is treated as an empty dict.
+            activation (str): Forwarded to ``Transformer``. Defaults to "gelu".
+
+        """
+        super().__init__()
+
+        # Configure vector quantized backbone. configure_backbone performs
+        # membership tests on the argument dict, so never hand it None.
+        if backbone_args is None:
+            backbone_args = {}
+        self.backbone = configure_backbone(backbone, backbone_args)
+        self.conv = nn.Sequential(
+            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2),
+            nn.ReLU(inplace=True),
+        )
+
+        # Configure embeddings for Transformer network.
+        self.trg_pad_index = trg_pad_index
+        self.vocab_size = vocab_size
+        self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim)
+        self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
+        self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+        nn.init.normal_(self.character_embedding.weight, std=0.02)
+
+        self.adaptive_pool = (
+            nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None
+        )
+
+        self.transformer = Transformer(
+            num_encoder_layers,
+            num_decoder_layers,
+            hidden_dim,
+            num_heads,
+            expansion_dim,
+            dropout_rate,
+            activation,
+        )
+
+        # Kept as a one-element Sequential so parameter names stay `head.0.*`.
+        self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
+
+    def _create_trg_mask(self, trg: Tensor) -> Tensor:
+        """Combine the padding mask with a lower-triangular (causal) mask."""
+        # Move this outside the transformer.
+        trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
+        trg_len = trg.shape[1]
+        trg_sub_mask = torch.tril(
+            torch.ones((trg_len, trg_len), device=trg.device)
+        ).bool()
+        trg_mask = trg_pad_mask & trg_sub_mask
+        return trg_mask
+
+    def encoder(self, src: Tensor) -> Tensor:
+        """Forward pass with the encoder of the transformer."""
+        return self.transformer.encoder(src)
+
+    def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
+        """Forward pass with the decoder of the transformer + classification head."""
+        return self.head(
+            self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+        )
+
+    def extract_image_features(self, src: Tensor) -> Tuple[Tensor, Tensor]:
+        """Extracts image features with a backbone neural network.
+
+        The backbone output goes through a strided convolution and is then laid
+        out as a sequence. With adaptive pooling enabled, the width axis is moved
+        in front of the channel and height axes before pooling and the last axis
+        is squeezed; otherwise width and height are flattened into one axis.
+
+        Args:
+            src (Tensor): Input tensor.
+
+        Returns:
+            Tuple[Tensor, Tensor]: The input src to the transformer and the vq loss.
+
+        """
+        # If batch dimension is missing, it needs to be added.
+        if len(src.shape) < 4:
+            src = src[(None,) * (4 - len(src.shape))]
+        src, vq_loss = self.backbone.encode(src)
+        src = self.conv(src)
+
+        if self.adaptive_pool is not None:
+            # AdaptiveAvgPool2d pools over the two trailing axes, so the width
+            # axis is moved out of the way first.
+            src = rearrange(src, "b c h w -> b w c h")
+            src = self.adaptive_pool(src)
+            src = src.squeeze(3)
+        else:
+            src = rearrange(src, "b c h w -> b (w h) c")
+
+        _, t, _ = src.shape
+
+        # Out-of-place add: `src` may be a view of the in-place ReLU output, and
+        # writing into it could invalidate tensors saved for the backward pass.
+        src = src + self.src_position_embedding[:, :t]
+
+        return src, vq_loss
+
+    def target_embedding(self, trg: Tensor) -> Tensor:
+        """Encodes target tensor with embedding and position.
+
+        Args:
+            trg (Tensor): Target tensor.
+
+        Returns:
+            Tensor: Encoded target tensor.
+
+        """
+        trg = self.character_embedding(trg.long())
+        trg = self.trg_position_encoding(trg)
+        return trg
+
+    def decode_image_features(
+        self, image_features: Tensor, trg: Optional[Tensor] = None
+    ) -> Tensor:
+        """Takes images features from the backbone and decodes them with the transformer.
+
+        Args:
+            image_features (Tensor): Source sequence from `extract_image_features`.
+            trg (Optional[Tensor]): Target token indices. Required despite the
+                default; the default only keeps the signature backward compatible.
+
+        Returns:
+            Tensor: Logits over the vocabulary.
+
+        Raises:
+            ValueError: If `trg` is None.
+
+        """
+        if trg is None:
+            raise ValueError("trg must be provided to decode image features.")
+        trg_mask = self._create_trg_mask(trg)
+        trg = self.target_embedding(trg)
+        out = self.transformer(image_features, trg, trg_mask=trg_mask)
+
+        logits = self.head(out)
+        return logits
+
+    def forward(
+        self, x: Tensor, trg: Optional[Tensor] = None
+    ) -> Tuple[Tensor, Tensor]:
+        """Forward pass with CNN transformer; returns the logits and the vq loss."""
+        image_features, vq_loss = self.extract_image_features(x)
+        logits = self.decode_image_features(image_features, trg)
+        return logits, vq_loss
diff --git a/src/text_recognizer/networks/vqvae/__init__.py b/src/text_recognizer/networks/vqvae/__init__.py
index e1f05fa..763953c 100644
--- a/src/text_recognizer/networks/vqvae/__init__.py
+++ b/src/text_recognizer/networks/vqvae/__init__.py
@@ -1 +1,5 @@
"""VQ-VAE module."""
+from .decoder import Decoder
+from .encoder import Encoder
+from .vector_quantizer import VectorQuantizer
+from .vqvae import VQVAE
diff --git a/src/text_recognizer/networks/vqvae/decoder.py b/src/text_recognizer/networks/vqvae/decoder.py
new file mode 100644
index 0000000..8847aba
--- /dev/null
+++ b/src/text_recognizer/networks/vqvae/decoder.py
@@ -0,0 +1,133 @@
+"""CNN decoder for the VQ-VAE."""
+
+from typing import List, Optional, Tuple, Type
+
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import activation_function
+from text_recognizer.networks.vqvae.encoder import _ResidualBlock
+
+
+class Decoder(nn.Module):
+    """A CNN decoder network."""
+
+    def __init__(
+        self,
+        channels: List[int],
+        kernel_sizes: List[int],
+        strides: List[int],
+        num_residual_layers: int,
+        embedding_dim: int,
+        upsampling: Optional[List[List[int]]] = None,
+        activation: str = "leaky_relu",
+        dropout_rate: float = 0.0,
+    ) -> None:
+        """Build the residual bottleneck and the transposed-convolution stages.
+
+        Args:
+            channels (List[int]): `channels[0]` is the width of the bottleneck;
+                `channels[1:]` are the output channels of the decompression stages.
+            kernel_sizes (List[int]): Kernel sizes of the decompression stages.
+            strides (List[int]): Strides of the decompression stages.
+            num_residual_layers (int): Number of residual blocks in the bottleneck.
+            embedding_dim (int): Channel count of the incoming codes.
+            upsampling (Optional[List[List[int]]]): Target sizes for the
+                `nn.Upsample` inserted after the i-th stage. None disables it.
+            activation (str): Name resolved with `activation_function`.
+            dropout_rate (float): Dropout probability; 0 disables dropout.
+
+        """
+        super().__init__()
+
+        if dropout_rate:
+            if activation == "selu":
+                dropout = nn.AlphaDropout(p=dropout_rate)
+            else:
+                dropout = nn.Dropout(p=dropout_rate)
+        else:
+            dropout = None
+
+        self.upsampling = upsampling
+
+        self.res_block = nn.ModuleList([])
+        self.upsampling_block = nn.ModuleList([])
+
+        self.embedding_dim = embedding_dim
+        activation = activation_function(activation)
+
+        # Configure decoder.
+        self.decoder = self._build_decoder(
+            channels, kernel_sizes, strides, num_residual_layers, activation, dropout,
+        )
+
+    def _build_decompression_block(
+        self,
+        in_channels: int,
+        channels: List[int],
+        kernel_sizes: List[int],
+        strides: List[int],
+        activation: Type[nn.Module],
+        dropout: Optional[Type[nn.Module]],
+    ) -> nn.ModuleList:
+        """Create the transposed-convolution stages and the final output layer.
+
+        Raises:
+            ValueError: If the configuration yields no stage, since the output
+                layer reuses the kernel size and stride of the last stage.
+
+        """
+        # The constructor default is None, which means "no upsampling layers".
+        upsampling = self.upsampling if self.upsampling is not None else []
+
+        modules = nn.ModuleList([])
+        kernel_size = stride = None
+        configuration = zip(channels, kernel_sizes, strides)
+        for i, (out_channels, kernel_size, stride) in enumerate(configuration):
+            modules.append(
+                nn.Sequential(
+                    nn.ConvTranspose2d(
+                        in_channels,
+                        out_channels,
+                        kernel_size,
+                        stride=stride,
+                        padding=1,
+                    ),
+                    activation,
+                )
+            )
+
+            if i < len(upsampling):
+                modules.append(nn.Upsample(size=upsampling[i]),)
+
+            if dropout is not None:
+                modules.append(dropout)
+
+            in_channels = out_channels
+
+        if kernel_size is None:
+            raise ValueError(
+                "Decoder needs at least one decompression stage; got empty "
+                "channels, kernel_sizes or strides."
+            )
+
+        # Output layer: single channel, same kernel size and stride as the last
+        # stage above. The two layers are added flat, not as a nested Sequential.
+        modules.extend(
+            [
+                nn.ConvTranspose2d(
+                    in_channels, 1, kernel_size=kernel_size, stride=stride, padding=1
+                ),
+                nn.Tanh(),
+            ]
+        )
+
+        return modules
+
+    def _build_decoder(
+        self,
+        channels: List[int],
+        kernel_sizes: List[int],
+        strides: List[int],
+        num_residual_layers: int,
+        activation: Type[nn.Module],
+        dropout: Optional[Type[nn.Module]],
+    ) -> nn.Sequential:
+        """Assemble the bottleneck and decompression modules into one Sequential."""
+        # Project the codes from embedding_dim to the bottleneck width.
+        self.res_block.append(
+            nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,)
+        )
+
+        # Bottleneck module.
+        self.res_block.extend(
+            [
+                _ResidualBlock(channels[0], channels[0], dropout)
+                for _ in range(num_residual_layers)
+            ]
+        )
+
+        # Decompression module
+        self.upsampling_block.extend(
+            self._build_decompression_block(
+                channels[0], channels[1:], kernel_sizes, strides, activation, dropout
+            )
+        )
+
+        # Replace the temporary ModuleLists with Sequentials under the same
+        # attribute names.
+        self.res_block = nn.Sequential(*self.res_block)
+        self.upsampling_block = nn.Sequential(*self.upsampling_block)
+
+        return nn.Sequential(self.res_block, self.upsampling_block)
+
+    def forward(self, z_q: Tensor) -> Tensor:
+        """Reconstruct input from given codes."""
+        x_reconstruction = self.decoder(z_q)
+        return x_reconstruction
diff --git a/src/text_recognizer/networks/vqvae/encoder.py b/src/text_recognizer/networks/vqvae/encoder.py
index 60c4c43..d3adac5 100644
--- a/src/text_recognizer/networks/vqvae/encoder.py
+++ b/src/text_recognizer/networks/vqvae/encoder.py
@@ -1,6 +1,5 @@
"""CNN encoder for the VQ-VAE."""
-
-from typing import List, Optional, Type
+from typing import List, Optional, Tuple, Type
import torch
from torch import nn
@@ -12,16 +11,12 @@ from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer
class _ResidualBlock(nn.Module):
def __init__(
- self,
- in_channels: int,
- out_channels: int,
- activation: Type[nn.Module],
- dropout: Optional[Type[nn.Module]],
+ self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]],
) -> None:
super().__init__()
self.block = [
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
- activation,
+ nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False),
]
@@ -42,23 +37,111 @@ class Encoder(nn.Module):
self,
in_channels: int,
channels: List[int],
+ kernel_sizes: List[int],
+ strides: List[int],
num_residual_layers: int,
embedding_dim: int,
num_embeddings: int,
beta: float = 0.25,
- activation: str = "elu",
+ activation: str = "leaky_relu",
dropout_rate: float = 0.0,
) -> None:
super().__init__()
- pass
- # if dropout_rate:
- # if activation == "selu":
- # dropout = nn.AlphaDropout(p=dropout_rate)
- # else:
- # dropout = nn.Dropout(p=dropout_rate)
- # else:
- # dropout = None
-
- def _build_encoder(self) -> nn.Sequential:
- # TODO: Continue to implement encoder.
- pass
+
+ if dropout_rate:
+ if activation == "selu":
+ dropout = nn.AlphaDropout(p=dropout_rate)
+ else:
+ dropout = nn.Dropout(p=dropout_rate)
+ else:
+ dropout = None
+
+ self.embedding_dim = embedding_dim
+ self.num_embeddings = num_embeddings
+ self.beta = beta
+ activation = activation_function(activation)
+
+ # Configure encoder.
+ self.encoder = self._build_encoder(
+ in_channels,
+ channels,
+ kernel_sizes,
+ strides,
+ num_residual_layers,
+ activation,
+ dropout,
+ )
+
+ # Configure Vector Quantizer.
+ self.vector_quantizer = VectorQuantizer(
+ self.num_embeddings, self.embedding_dim, self.beta
+ )
+
+    def _build_compression_block(
+        self,
+        in_channels: int,
+        channels: List[int],
+        kernel_sizes: List[int],
+        strides: List[int],
+        activation: Type[nn.Module],
+        dropout: Optional[Type[nn.Module]],
+    ) -> nn.ModuleList:
+        """Create the convolution stages that compress the input.
+
+        One `Conv2d` + activation stage is built per (channel, kernel size,
+        stride) triple; `zip` stops at the shortest of the three lists. When a
+        dropout module is given, the same instance is appended after every stage.
+
+        Args:
+            in_channels (int): Channel count of the input to the first stage.
+            channels (List[int]): Output channels of each stage.
+            kernel_sizes (List[int]): Kernel size of each stage.
+            strides (List[int]): Stride of each stage.
+            activation (Type[nn.Module]): Activation module placed after each conv.
+            dropout (Optional[Type[nn.Module]]): Dropout module, or None.
+
+        Returns:
+            nn.ModuleList: The stages in order.
+
+        """
+        modules = nn.ModuleList([])
+        configuration = zip(channels, kernel_sizes, strides)
+        for out_channels, kernel_size, stride in configuration:
+            modules.append(
+                nn.Sequential(
+                    nn.Conv2d(
+                        in_channels, out_channels, kernel_size, stride=stride, padding=1
+                    ),
+                    activation,
+                )
+            )
+
+            if dropout is not None:
+                modules.append(dropout)
+
+            # The next stage consumes what this one produced.
+            in_channels = out_channels
+
+        return modules
+
+    def _build_encoder(
+        self,
+        in_channels: int,
+        channels: List[int],
+        kernel_sizes: List[int],
+        strides: List[int],
+        num_residual_layers: int,
+        activation: Type[nn.Module],
+        dropout: Optional[Type[nn.Module]],
+    ) -> nn.Sequential:
+        """Assemble compression stages, residual bottleneck and code projection.
+
+        Args:
+            in_channels (int): Channel count of the network input.
+            channels (List[int]): Output channels of the compression stages; the
+                last entry is also the width of the bottleneck.
+            kernel_sizes (List[int]): Kernel sizes of the compression stages.
+            strides (List[int]): Strides of the compression stages.
+            num_residual_layers (int): Number of residual blocks in the bottleneck.
+            activation (Type[nn.Module]): Activation module for the compression stages.
+            dropout (Optional[Type[nn.Module]]): Dropout module, or None.
+
+        Returns:
+            nn.Sequential: The complete encoder.
+
+        """
+        encoder = nn.ModuleList([])
+
+        # compression module
+        encoder.extend(
+            self._build_compression_block(
+                in_channels, channels, kernel_sizes, strides, activation, dropout
+            )
+        )
+
+        # Bottleneck module.
+        encoder.extend(
+            [
+                _ResidualBlock(channels[-1], channels[-1], dropout)
+                for _ in range(num_residual_layers)
+            ]
+        )
+
+        # 1x1 convolution projecting the bottleneck width to embedding_dim.
+        encoder.append(
+            nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,)
+        )
+
+        return nn.Sequential(*encoder)
+
+    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+        """Encode the input and pass the result through the vector quantizer."""
+        latent = self.encoder(x)
+        quantized, vq_loss = self.vector_quantizer(latent)
+        return quantized, vq_loss
diff --git a/src/text_recognizer/networks/vqvae/vector_quantizer.py b/src/text_recognizer/networks/vqvae/vector_quantizer.py
index 25e5583..f92c7ee 100644
--- a/src/text_recognizer/networks/vqvae/vector_quantizer.py
+++ b/src/text_recognizer/networks/vqvae/vector_quantizer.py
@@ -26,7 +26,7 @@ class VectorQuantizer(nn.Module):
self.embedding = nn.Embedding(self.K, self.D)
# Initialize the codebook.
- self.embedding.weight.uniform_(-1 / self.K, 1 / self.K)
+ nn.init.uniform_(self.embedding.weight, -1 / self.K, 1 / self.K)
def discretization_bottleneck(self, latent: Tensor) -> Tensor:
"""Computes the code nearest to the latent representation.
diff --git a/src/text_recognizer/networks/vqvae/vqvae.py b/src/text_recognizer/networks/vqvae/vqvae.py
new file mode 100644
index 0000000..50448b4
--- /dev/null
+++ b/src/text_recognizer/networks/vqvae/vqvae.py
@@ -0,0 +1,74 @@
+"""The VQ-VAE."""
+
+from typing import List, Optional, Tuple, Type
+
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.vqvae import Decoder, Encoder
+
+
+class VQVAE(nn.Module):
+    """Vector Quantized Variational AutoEncoder."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        channels: List[int],
+        kernel_sizes: List[int],
+        strides: List[int],
+        num_residual_layers: int,
+        embedding_dim: int,
+        num_embeddings: int,
+        upsampling: Optional[List[List[int]]] = None,
+        beta: float = 0.25,
+        activation: str = "leaky_relu",
+        dropout_rate: float = 0.0,
+    ) -> None:
+        """Build an encoder and a decoder that mirrors its configuration.
+
+        Args:
+            in_channels (int): Forwarded to `Encoder`.
+            channels (List[int]): Forwarded to `Encoder`; a reversed copy goes to
+                `Decoder`.
+            kernel_sizes (List[int]): Forwarded to `Encoder`; a reversed copy goes
+                to `Decoder`.
+            strides (List[int]): Forwarded to `Encoder`; a reversed copy goes to
+                `Decoder`.
+            num_residual_layers (int): Forwarded to both networks.
+            embedding_dim (int): Forwarded to both networks.
+            num_embeddings (int): Forwarded to `Encoder`.
+            upsampling (Optional[List[List[int]]]): Forwarded to `Decoder`.
+            beta (float): Forwarded to `Encoder`. Defaults to 0.25.
+            activation (str): Forwarded to both networks.
+            dropout_rate (float): Forwarded to both networks.
+
+        """
+        super().__init__()
+
+        # configure encoder.
+        self.encoder = Encoder(
+            in_channels,
+            channels,
+            kernel_sizes,
+            strides,
+            num_residual_layers,
+            embedding_dim,
+            num_embeddings,
+            beta,
+            activation,
+            dropout_rate,
+        )
+
+        # Configure decoder. Reversed copies are used so the caller's lists
+        # (often a config dict that is reused or saved) are left untouched.
+        self.decoder = Decoder(
+            list(reversed(channels)),
+            list(reversed(kernel_sizes)),
+            list(reversed(strides)),
+            num_residual_layers,
+            embedding_dim,
+            upsampling,
+            activation,
+            dropout_rate,
+        )
+
+    def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+        """Encodes input to a latent code."""
+        return self.encoder(x)
+
+    def decode(self, z_q: Tensor) -> Tensor:
+        """Reconstructs input from latent codes."""
+        return self.decoder(z_q)
+
+    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+        """Compresses and decompresses input."""
+        # Prepend singleton axes until the input has four dimensions.
+        if len(x.shape) < 4:
+            x = x[(None,) * (4 - len(x.shape))]
+        z_q, vq_loss = self.encode(x)
+        x_reconstruction = self.decode(z_q)
+        return x_reconstruction, vq_loss
diff --git a/src/text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.pt b/src/text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.pt
new file mode 100644
index 0000000..b5295c2
--- /dev/null
+++ b/src/text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.pt
Binary files differ
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py
index 2c9a196..faafea6 100644
--- a/src/training/run_experiment.py
+++ b/src/training/run_experiment.py
@@ -296,7 +296,12 @@ def run_experiment(
# Run inference over test set.
if test:
logger.info("Loading checkpoint with the best weights.")
- model.load_from_checkpoint(model_dir / "best.pt")
+ if "checkpoint" in experiment_config["train_args"]:
+ model.load_from_checkpoint(
+ model_dir / experiment_config["train_args"]["checkpoint"]
+ )
+ else:
+ model.load_from_checkpoint(model_dir / "best.pt")
logger.info("Running inference on test set.")
if experiment_config["criterion"]["type"] == "EmbeddingLoss":
diff --git a/src/training/trainer/callbacks/__init__.py b/src/training/trainer/callbacks/__init__.py
index 95ec142..80c4177 100644
--- a/src/training/trainer/callbacks/__init__.py
+++ b/src/training/trainer/callbacks/__init__.py
@@ -7,7 +7,12 @@ from .lr_schedulers import (
SWA,
)
from .progress_bar import ProgressBar
-from .wandb_callbacks import WandbCallback, WandbImageLogger, WandbSegmentationLogger
+from .wandb_callbacks import (
+ WandbCallback,
+ WandbImageLogger,
+ WandbReconstructionLogger,
+ WandbSegmentationLogger,
+)
__all__ = [
"Callback",
@@ -17,6 +22,7 @@ __all__ = [
"LRScheduler",
"WandbCallback",
"WandbImageLogger",
+ "WandbReconstructionLogger",
"WandbSegmentationLogger",
"ProgressBar",
"SWA",
diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py
index 20414df..552a4f4 100644
--- a/src/training/trainer/callbacks/wandb_callbacks.py
+++ b/src/training/trainer/callbacks/wandb_callbacks.py
@@ -201,3 +201,61 @@ class WandbSegmentationLogger(Callback):
)
wandb.log({f"{self.caption}": images}, commit=False)
+
+
+class WandbReconstructionLogger(Callback):
+    """W&B callback that logs images next to their reconstructions."""
+
+    def __init__(
+        self, example_indices: Optional[List] = None, num_examples: int = 4,
+    ) -> None:
+        """Stores the sampling configuration of the logger.
+
+        Args:
+            example_indices (Optional[List]): Indices for validation images. Defaults to None.
+            num_examples (int): Number of random samples to take if example_indices are not
+                specified. Defaults to 4.
+
+        """
+        super().__init__()
+        self.num_examples = num_examples
+        self.example_indices = example_indices
+        self.test_sample_indices = None
+        self.caption = None
+
+    def set_model(self, model: Type[Model]) -> None:
+        """Stores the model and picks the validation images to log."""
+        self.model = model
+        self.caption = "Validation Reconstructions Examples"
+        if self.example_indices is None:
+            # No indices given: draw num_examples random ones.
+            num_validation = len(self.model.val_dataset)
+            self.example_indices = np.random.randint(
+                0, num_validation, self.num_examples
+            )
+        self.images = self.model.val_dataset.dataset.data[self.example_indices]
+
+    def on_test_begin(self) -> None:
+        """Switches the caption and the images over to the test dataset."""
+        self.caption = "Test Reconstructions Examples"
+        if self.test_sample_indices is None:
+            num_test = len(self.model.test_dataset)
+            self.test_sample_indices = np.random.randint(
+                0, num_test, self.num_examples
+            )
+        self.images = self.model.test_dataset.data[self.test_sample_indices]
+
+    def on_test_end(self) -> None:
+        """Logs the test images by reusing the epoch-end hook."""
+        self.on_epoch_end(0, {})
+
+    def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+        """Logs every stored image followed by the model's reconstruction of it."""
+        panels = []
+        for original in self.images:
+            prediction = self.model.predict_on_image(original)
+            reconstruction = prediction.detach().squeeze(0).cpu().numpy()
+            panels.extend((original, reconstruction))
+
+        payload = {f"{self.caption}": [wandb.Image(panel) for panel in panels]}
+        wandb.log(payload, commit=False)
diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py
index 40a25da..b770c94 100644
--- a/src/training/trainer/train.py
+++ b/src/training/trainer/train.py
@@ -12,7 +12,7 @@ import torch
from torch import Tensor
from torch.optim.swa_utils import update_bn
from training.trainer.callbacks import Callback, CallbackList, LRScheduler, SWA
-from training.trainer.util import log_val_metric, RunningAverage
+from training.trainer.util import log_val_metric
import wandb
from text_recognizer.models import Model
@@ -30,8 +30,6 @@ warnings.filterwarnings("ignore")
class Trainer:
"""Trainer for training PyTorch models."""
- # TODO: proper add teardown?
-
def __init__(
self,
max_epochs: int,
@@ -46,7 +44,7 @@ class Trainer:
max_epochs (int): The maximum number of epochs in the training loop.
callbacks (CallbackList): List of callbacks to be called.
transformer_model (bool): Transformer model flag, modifies the input to the model. Default is False.
- max_norm (float): Max norm for gradient clipping. Defaults to 0.0.
+ max_norm (float): Max norm for gradient clipping. Defaults to 0.0.
freeze_backbone (Optional[int]): How many epochs to freeze the backbone for. Used when training
Transformers. Default is None.
@@ -79,35 +77,32 @@ class Trainer:
self.callbacks = CallbackList(self.model, self.callbacks)
def compute_metrics(
- self,
- output: Tensor,
- targets: Tensor,
- loss: Tensor,
- loss_avg: Type[RunningAverage],
+ self, output: Tensor, targets: Tensor, loss: Tensor, batch_size: int
) -> Dict:
"""Computes metrics for output and target pairs."""
# Compute metrics.
loss = loss.detach().float().item()
- loss_avg.update(loss)
output = output.detach()
targets = targets.detach()
if self.model.metrics is not None:
- metrics = {
- metric: self.model.metrics[metric](output, targets)
- for metric in self.model.metrics
- }
+ metrics = {}
+ for metric in self.model.metrics:
+ if metric == "cer" or metric == "wer":
+ metrics[metric] = self.model.metrics[metric](
+ output,
+ targets,
+ batch_size,
+ self.model.mapper(self.model.pad_token),
+ )
+ else:
+ metrics[metric] = self.model.metrics[metric](output, targets)
else:
metrics = {}
metrics["loss"] = loss
return metrics
- def training_step(
- self,
- batch: int,
- samples: Tuple[Tensor, Tensor],
- loss_avg: Type[RunningAverage],
- ) -> Dict:
+ def training_step(self, batch: int, samples: Tuple[Tensor, Tensor],) -> Dict:
"""Performs the training step."""
# Pass the tensor to the device for computation.
data, targets = samples
@@ -116,25 +111,43 @@ class Trainer:
targets.to(self.model.device),
)
+ batch_size = data.shape[0]
+
+ # Placeholder for auxiliary loss.
+ aux_loss = None
+
# Forward pass.
# Get the network prediction.
if self.transformer_model:
if self.freeze_backbone is not None and batch < self.freeze_backbone:
with torch.no_grad():
image_features = self.model.network.extract_image_features(data)
+
+ if isinstance(image_features, Tuple):
+ image_features, _ = image_features
+
output = self.model.network.decode_image_features(
image_features, targets[:, :-1]
)
else:
output = self.model.network.forward(data, targets[:, :-1])
+ if isinstance(output, Tuple):
+ output, aux_loss = output
output = rearrange(output, "b t v -> (b t) v")
targets = rearrange(targets[:, 1:], "b t -> (b t)").long()
else:
output = self.model.forward(data)
+ if isinstance(output, Tuple):
+ output, aux_loss = output
+ targets = data
+
# Compute the loss.
loss = self.model.criterion(output, targets)
+ if aux_loss is not None:
+ loss += aux_loss
+
# Backward pass.
# Clear the previous gradients.
for p in self.model.network.parameters():
@@ -151,7 +164,7 @@ class Trainer:
# Perform updates using calculated gradients.
self.model.optimizer.step()
- metrics = self.compute_metrics(output, targets, loss, loss_avg)
+ metrics = self.compute_metrics(output, targets, loss, batch_size)
return metrics
@@ -160,22 +173,15 @@ class Trainer:
# Set model to traning mode.
self.model.train()
- # Running average for the loss.
- loss_avg = RunningAverage()
-
for batch, samples in enumerate(self.model.train_dataloader()):
self.callbacks.on_train_batch_begin(batch)
- metrics = self.training_step(batch, samples, loss_avg)
+ metrics = self.training_step(batch, samples)
self.callbacks.on_train_batch_end(batch, logs=metrics)
@torch.no_grad()
- def validation_step(
- self,
- batch: int,
- samples: Tuple[Tensor, Tensor],
- loss_avg: Type[RunningAverage],
- ) -> Dict:
+ def validation_step(self, batch: int, samples: Tuple[Tensor, Tensor],) -> Dict:
"""Performs the validation step."""
+
# Pass the tensor to the device for computation.
data, targets = samples
data, targets = (
@@ -183,21 +189,35 @@ class Trainer:
targets.to(self.model.device),
)
+ batch_size = data.shape[0]
+
+ # Placeholder for auxiliary loss.
+ aux_loss = None
+
# Forward pass.
# Get the network prediction.
# Use SWA if available and using test dataset.
if self.transformer_model:
output = self.model.network.forward(data, targets[:, :-1])
+ if isinstance(output, Tuple):
+ output, aux_loss = output
output = rearrange(output, "b t v -> (b t) v")
targets = rearrange(targets[:, 1:], "b t -> (b t)").long()
else:
output = self.model.forward(data)
+ if isinstance(output, Tuple):
+ output, aux_loss = output
+ targets = data
+
# Compute the loss.
loss = self.model.criterion(output, targets)
+ if aux_loss is not None:
+ loss += aux_loss
+
# Compute metrics.
- metrics = self.compute_metrics(output, targets, loss, loss_avg)
+ metrics = self.compute_metrics(output, targets, loss, batch_size)
return metrics
@@ -206,15 +226,12 @@ class Trainer:
# Set model to eval mode.
self.model.eval()
- # Running average for the loss.
- loss_avg = RunningAverage()
-
# Summary for the current eval loop.
summary = []
for batch, samples in enumerate(self.model.val_dataloader()):
self.callbacks.on_validation_batch_begin(batch)
- metrics = self.validation_step(batch, samples, loss_avg)
+ metrics = self.validation_step(batch, samples)
self.callbacks.on_validation_batch_end(batch, logs=metrics)
summary.append(metrics)
@@ -287,14 +304,11 @@ class Trainer:
# Check if SWA network is available.
self.model.use_swa_model()
- # Running average for the loss.
- loss_avg = RunningAverage()
-
# Summary for the current test loop.
summary = []
for batch, samples in enumerate(self.model.test_dataloader()):
- metrics = self.validation_step(batch, samples, loss_avg)
+ metrics = self.validation_step(batch, samples)
summary.append(metrics)
self.callbacks.on_test_end()