summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--data/EMNIST/raw/metadata.toml3
-rw-r--r--noxfile.py9
-rw-r--r--poetry.lock107
-rw-r--r--pyproject.toml8
-rw-r--r--tests/__init__.py (renamed from text_recognizer/tests/__init__.py)0
-rw-r--r--tests/support/__init__.py (renamed from text_recognizer/tests/support/__init__.py)0
-rw-r--r--tests/support/create_emnist_lines_support_files.py (renamed from text_recognizer/tests/support/create_emnist_lines_support_files.py)0
-rw-r--r--tests/support/create_emnist_support_files.py (renamed from text_recognizer/tests/support/create_emnist_support_files.py)0
-rw-r--r--tests/support/create_iam_lines_support_files.py (renamed from text_recognizer/tests/support/create_iam_lines_support_files.py)0
-rw-r--r--tests/support/emnist_lines/Knox Ky<eos>.png (renamed from text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png)bin2301 -> 2301 bytes
-rw-r--r--tests/support/emnist_lines/ancillary beliefs and<eos>.png (renamed from text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png)bin5424 -> 5424 bytes
-rw-r--r--tests/support/emnist_lines/they<eos>.png (renamed from text_recognizer/tests/support/emnist_lines/they<eos>.png)bin1391 -> 1391 bytes
-rw-r--r--tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png (renamed from text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png)bin5170 -> 5170 bytes
-rw-r--r--tests/support/iam_lines/and came into the livingroom, where<eos>.png (renamed from text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png)bin3617 -> 3617 bytes
-rw-r--r--tests/support/iam_lines/his entrance. He came, almost falling<eos>.png (renamed from text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png)bin3923 -> 3923 bytes
-rw-r--r--tests/support/iam_paragraphs/a01-000u.jpg (renamed from text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg)bin14890 -> 14890 bytes
-rw-r--r--tests/test_character_predictor.py (renamed from text_recognizer/tests/test_character_predictor.py)0
-rw-r--r--tests/test_line_predictor.py (renamed from text_recognizer/tests/test_line_predictor.py)0
-rw-r--r--tests/test_paragraph_text_recognizer.py (renamed from text_recognizer/tests/test_paragraph_text_recognizer.py)0
-rw-r--r--text_recognizer/datasets/__init__.py38
-rw-r--r--text_recognizer/datasets/base_data_module.py36
-rw-r--r--text_recognizer/datasets/base_dataset.py21
-rw-r--r--text_recognizer/datasets/dataset.py152
-rw-r--r--text_recognizer/datasets/download_utils.py6
-rw-r--r--text_recognizer/datasets/emnist.py88
-rw-r--r--text_recognizer/datasets/emnist_essentials.json1
-rw-r--r--text_recognizer/datasets/emnist_lines.py184
27 files changed, 345 insertions, 308 deletions
diff --git a/data/EMNIST/raw/metadata.toml b/data/EMNIST/raw/metadata.toml
deleted file mode 100644
index 10304ce..0000000
--- a/data/EMNIST/raw/metadata.toml
+++ /dev/null
@@ -1,3 +0,0 @@
-filename = 'gzip.zip'
-md5 = '58c8d27c78d21e728a6bc7b3cc06412e'
-url = 'http://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip'
diff --git a/noxfile.py b/noxfile.py
index 098a551..0e2ac7b 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -8,7 +8,14 @@ from nox.sessions import Session
package = "text-recognizer"
nox.options.sessions = "lint", "mypy", "pytype", "safety", "tests"
-locations = "src", "tests", "noxfile.py", "docs/conf.py", "src/text_recognizer/tests"
+locations = (
+ "text_recognizer",
+ "training",
+ "tasks",
+ "tests",
+ "noxfile.py",
+ "docs/conf.py",
+)
def install_with_constraints(session: Session, *args: str, **kwargs: Any) -> None:
diff --git a/poetry.lock b/poetry.lock
index a389e98..094043d 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1730,14 +1730,14 @@ alldeps = ["numpy (>=1.13.3)", "scipy (>=0.19.1)"]
[[package]]
name = "scipy"
-version = "1.5.4"
+version = "1.6.1"
description = "SciPy: Scientific Library for Python"
category = "main"
optional = false
-python-versions = ">=3.6"
+python-versions = ">=3.7"
[package.dependencies]
-numpy = ">=1.14.5"
+numpy = ">=1.16.5"
[[package]]
name = "send2trash"
@@ -2033,7 +2033,7 @@ python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
[[package]]
name = "torch"
-version = "1.7.1"
+version = "1.8.0"
description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
category = "main"
optional = false
@@ -2053,7 +2053,7 @@ python-versions = ">=3.5"
[[package]]
name = "torchvision"
-version = "0.8.2"
+version = "0.9.0"
description = "image and video datasets and models for torch deep learning"
category = "main"
optional = false
@@ -2062,7 +2062,7 @@ python-versions = "*"
[package.dependencies]
numpy = "*"
pillow = ">=4.1.1"
-torch = "1.7.1"
+torch = "1.8.0"
[package.extras]
scipy = ["scipy"]
@@ -2279,7 +2279,7 @@ multidict = ">=4.0"
[metadata]
lock-version = "1.1"
python-versions = "^3.8"
-content-hash = "45c6282e27a0231e3dc3c951a9540f3cd2f6d6eb3d3eda1eab84168014d8062a"
+content-hash = "82834b5e12f38b06a306f1b8c05ed6ee33d671c747de40615d2bd36f3f138279"
[metadata.files]
absl-py = [
@@ -3385,31 +3385,25 @@ scikit-learn = [
{file = "scikit_learn-0.23.2-cp38-cp38-win_amd64.whl", hash = "sha256:1b8a391de95f6285a2f9adffb7db0892718950954b7149a70c783dc848f104ea"},
]
scipy = [
- {file = "scipy-1.5.4-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:4f12d13ffbc16e988fa40809cbbd7a8b45bc05ff6ea0ba8e3e41f6f4db3a9e47"},
- {file = "scipy-1.5.4-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:a254b98dbcc744c723a838c03b74a8a34c0558c9ac5c86d5561703362231107d"},
- {file = "scipy-1.5.4-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:368c0f69f93186309e1b4beb8e26d51dd6f5010b79264c0f1e9ca00cd92ea8c9"},
- {file = "scipy-1.5.4-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:4598cf03136067000855d6b44d7a1f4f46994164bcd450fb2c3d481afc25dd06"},
- {file = "scipy-1.5.4-cp36-cp36m-win32.whl", hash = "sha256:e98d49a5717369d8241d6cf33ecb0ca72deee392414118198a8e5b4c35c56340"},
- {file = "scipy-1.5.4-cp36-cp36m-win_amd64.whl", hash = "sha256:65923bc3809524e46fb7eb4d6346552cbb6a1ffc41be748535aa502a2e3d3389"},
- {file = "scipy-1.5.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:9ad4fcddcbf5dc67619379782e6aeef41218a79e17979aaed01ed099876c0e62"},
- {file = "scipy-1.5.4-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:f87b39f4d69cf7d7529d7b1098cb712033b17ea7714aed831b95628f483fd012"},
- {file = "scipy-1.5.4-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:25b241034215247481f53355e05f9e25462682b13bd9191359075682adcd9554"},
- {file = "scipy-1.5.4-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:fa789583fc94a7689b45834453fec095245c7e69c58561dc159b5d5277057e4c"},
- {file = "scipy-1.5.4-cp37-cp37m-win32.whl", hash = "sha256:d6d25c41a009e3c6b7e757338948d0076ee1dd1770d1c09ec131f11946883c54"},
- {file = "scipy-1.5.4-cp37-cp37m-win_amd64.whl", hash = "sha256:2c872de0c69ed20fb1a9b9cf6f77298b04a26f0b8720a5457be08be254366c6e"},
- {file = "scipy-1.5.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e360cb2299028d0b0d0f65a5c5e51fc16a335f1603aa2357c25766c8dab56938"},
- {file = "scipy-1.5.4-cp38-cp38-manylinux1_i686.whl", hash = "sha256:3397c129b479846d7eaa18f999369a24322d008fac0782e7828fa567358c36ce"},
- {file = "scipy-1.5.4-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:168c45c0c32e23f613db7c9e4e780bc61982d71dcd406ead746c7c7c2f2004ce"},
- {file = "scipy-1.5.4-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:213bc59191da2f479984ad4ec39406bf949a99aba70e9237b916ce7547b6ef42"},
- {file = "scipy-1.5.4-cp38-cp38-win32.whl", hash = "sha256:634568a3018bc16a83cda28d4f7aed0d803dd5618facb36e977e53b2df868443"},
- {file = "scipy-1.5.4-cp38-cp38-win_amd64.whl", hash = "sha256:b03c4338d6d3d299e8ca494194c0ae4f611548da59e3c038813f1a43976cb437"},
- {file = "scipy-1.5.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3d5db5d815370c28d938cf9b0809dade4acf7aba57eaf7ef733bfedc9b2474c4"},
- {file = "scipy-1.5.4-cp39-cp39-manylinux1_i686.whl", hash = "sha256:6b0ceb23560f46dd236a8ad4378fc40bad1783e997604ba845e131d6c680963e"},
- {file = "scipy-1.5.4-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:ed572470af2438b526ea574ff8f05e7f39b44ac37f712105e57fc4d53a6fb660"},
- {file = "scipy-1.5.4-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:8c8d6ca19c8497344b810b0b0344f8375af5f6bb9c98bd42e33f747417ab3f57"},
- {file = "scipy-1.5.4-cp39-cp39-win32.whl", hash = "sha256:d84cadd7d7998433334c99fa55bcba0d8b4aeff0edb123b2a1dfcface538e474"},
- {file = "scipy-1.5.4-cp39-cp39-win_amd64.whl", hash = "sha256:cc1f78ebc982cd0602c9a7615d878396bec94908db67d4ecddca864d049112f2"},
- {file = "scipy-1.5.4.tar.gz", hash = "sha256:4a453d5e5689de62e5d38edf40af3f17560bfd63c9c5bd228c18c1f99afa155b"},
+ {file = "scipy-1.6.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:a15a1f3fc0abff33e792d6049161b7795909b40b97c6cc2934ed54384017ab76"},
+ {file = "scipy-1.6.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:e79570979ccdc3d165456dd62041d9556fb9733b86b4b6d818af7a0afc15f092"},
+ {file = "scipy-1.6.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:a423533c55fec61456dedee7b6ee7dce0bb6bfa395424ea374d25afa262be261"},
+ {file = "scipy-1.6.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:33d6b7df40d197bdd3049d64e8e680227151673465e5d85723b3b8f6b15a6ced"},
+ {file = "scipy-1.6.1-cp37-cp37m-win32.whl", hash = "sha256:6725e3fbb47da428794f243864f2297462e9ee448297c93ed1dcbc44335feb78"},
+ {file = "scipy-1.6.1-cp37-cp37m-win_amd64.whl", hash = "sha256:5fa9c6530b1661f1370bcd332a1e62ca7881785cc0f80c0d559b636567fab63c"},
+ {file = "scipy-1.6.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bd50daf727f7c195e26f27467c85ce653d41df4358a25b32434a50d8870fc519"},
+ {file = "scipy-1.6.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:f46dd15335e8a320b0fb4685f58b7471702234cba8bb3442b69a3e1dc329c345"},
+ {file = "scipy-1.6.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:0e5b0ccf63155d90da576edd2768b66fb276446c371b73841e3503be1d63fb5d"},
+ {file = "scipy-1.6.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:2481efbb3740977e3c831edfd0bd9867be26387cacf24eb5e366a6a374d3d00d"},
+ {file = "scipy-1.6.1-cp38-cp38-win32.whl", hash = "sha256:68cb4c424112cd4be886b4d979c5497fba190714085f46b8ae67a5e4416c32b4"},
+ {file = "scipy-1.6.1-cp38-cp38-win_amd64.whl", hash = "sha256:5f331eeed0297232d2e6eea51b54e8278ed8bb10b099f69c44e2558c090d06bf"},
+ {file = "scipy-1.6.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0c8a51d33556bf70367452d4d601d1742c0e806cd0194785914daf19775f0e67"},
+ {file = "scipy-1.6.1-cp39-cp39-manylinux1_i686.whl", hash = "sha256:83bf7c16245c15bc58ee76c5418e46ea1811edcc2e2b03041b804e46084ab627"},
+ {file = "scipy-1.6.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:794e768cc5f779736593046c9714e0f3a5940bc6dcc1dba885ad64cbfb28e9f0"},
+ {file = "scipy-1.6.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:5da5471aed911fe7e52b86bf9ea32fb55ae93e2f0fac66c32e58897cfb02fa07"},
+ {file = "scipy-1.6.1-cp39-cp39-win32.whl", hash = "sha256:8e403a337749ed40af60e537cc4d4c03febddcc56cd26e774c9b1b600a70d3e4"},
+ {file = "scipy-1.6.1-cp39-cp39-win_amd64.whl", hash = "sha256:a5193a098ae9f29af283dcf0041f762601faf2e595c0db1da929875b7570353f"},
+ {file = "scipy-1.6.1.tar.gz", hash = "sha256:c4fceb864890b6168e79b0e714c585dbe2fd4222768ee90bc1aa0f8218691b11"},
]
send2trash = [
{file = "Send2Trash-1.5.0-py3-none-any.whl", hash = "sha256:f1691922577b6fa12821234aeb57599d887c4900b9ca537948d2dac34aea888b"},
@@ -3544,32 +3538,41 @@ toml = [
{file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
]
torch = [
- {file = "torch-1.7.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:422e64e98d0e100c360993819d0307e5d56e9517b26135808ad68984d577d75a"},
- {file = "torch-1.7.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f0aaf657145533824b15f2fd8fde8f8c67fe6c6281088ef588091f03fad90243"},
- {file = "torch-1.7.1-cp36-none-macosx_10_9_x86_64.whl", hash = "sha256:af464a6f4314a875035e0c4c2b07517599704b214634f4ed3ad2e748c5ef291f"},
- {file = "torch-1.7.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:5d76c255a41484c1d41a9ff570b9c9f36cb85df9428aa15a58ae16ac7cfc2ea6"},
- {file = "torch-1.7.1-cp37-cp37m-win_amd64.whl", hash = "sha256:d241c3f1c4d563e4ba86f84769c23e12606db167ee6f674eedff6d02901462e3"},
- {file = "torch-1.7.1-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:de84b4166e3f7335eb868b51d3bbd909ec33828af27290b4171bce832a55be3c"},
- {file = "torch-1.7.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:dd2fc6880c95e836960d86efbbc7f63d3287f2e1893c51d31f96dbfe02f0d73e"},
- {file = "torch-1.7.1-cp38-cp38-win_amd64.whl", hash = "sha256:e000b94be3aa58ad7f61e7d07cf379ea9366cf6c6874e68bd58ad0bdc537b3a7"},
- {file = "torch-1.7.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:2e49cac969976be63117004ee00d0a3e3dd4ea662ad77383f671b8992825de1a"},
- {file = "torch-1.7.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:a3793dcceb12b1e2281290cca1277c5ce86ddfd5bf044f654285a4d69057aea7"},
- {file = "torch-1.7.1-cp39-cp39-win_amd64.whl", hash = "sha256:6652a767a0572ae0feb74ad128758e507afd3b8396b6e7f147e438ba8d4c6f63"},
- {file = "torch-1.7.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:38d67f4fb189a92a977b2c0a38e4f6dd413e0bf55aa6d40004696df7e40a71ff"},
+ {file = "torch-1.8.0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:78b84115fd03f4587382a38b0da98cdd1827117806c80ebf97843a64213816cc"},
+ {file = "torch-1.8.0-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:86f13f324fd87870bd0d37864f4f5814dc27f9e7ed9ea222f1cc7d7dc01a8ffe"},
+ {file = "torch-1.8.0-cp36-cp36m-win_amd64.whl", hash = "sha256:394a99d777e487e773e0172cb0a0bce5b411e3090d89844e8dd55618be9bc970"},
+ {file = "torch-1.8.0-cp36-none-macosx_10_9_x86_64.whl", hash = "sha256:229a8dc38059ef6c7171f3f4f49c51e8a3d9644ce6c32dcddd9f1bac888a78aa"},
+ {file = "torch-1.8.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:6ecdbd4494b4bf2d31a24ddfbdff32bd995389bc8662a454bd40d3e8ce202907"},
+ {file = "torch-1.8.0-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:08aff0383e868f1e9882b732bbe6934defab690ad1745a03d5f1a150a4e1aeba"},
+ {file = "torch-1.8.0-cp37-cp37m-win_amd64.whl", hash = "sha256:c87c7b0fd31c331968674cb73e82396a622b06a8e20425584922b767f2ffb259"},
+ {file = "torch-1.8.0-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:b9d6c8c457b90b5167f3ab0bd1ff7193a06935533176bc6d41e1763d353e9740"},
+ {file = "torch-1.8.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:fa1e391cca3937d5dea31f31a1a80a01bd4a8062c039448c254bbf5a58eb0787"},
+ {file = "torch-1.8.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:affef9bce6eed232308dd89d55d3a37a105f35460f4705375980d27154c51e24"},
+ {file = "torch-1.8.0-cp38-cp38-win_amd64.whl", hash = "sha256:287a7677df844bf2c4425698fd6d9434065211219cd7fd96000ed981c4d92288"},
+ {file = "torch-1.8.0-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:1b58f70c150e066bcd7401a3bdfad661a04244817a5dac9990b5367523887d3f"},
+ {file = "torch-1.8.0-cp38-none-macosx_11_1_arm64.whl", hash = "sha256:923856c2e6e53d5a747d83ff40faadd791d27cea2fd881b8d6990ea269f47572"},
+ {file = "torch-1.8.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:2318fac860ae73dc6486c0de2223674d9ef6139fc75f157af2bf8dce4fca5524"},
+ {file = "torch-1.8.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:05b114cb793816cd140794d5874f463972cb639f3b55d3a060f21fd066f5b629"},
+ {file = "torch-1.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:7438431e03af793979cb1a9b5dd9f399b38461748e9f21f60e36149ee215d751"},
+ {file = "torch-1.8.0-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:d98d167994d2e30df61a98eaca1684c50761f096d7f76c0c99789ac8cea50b55"},
]
torch-summary = [
{file = "torch-summary-1.4.3.tar.gz", hash = "sha256:2dcbc1dfd07dca9f4080bcacdaf90db3f2fc28efee348c8fba9033039b0e8c82"},
{file = "torch_summary-1.4.3-py3-none-any.whl", hash = "sha256:a0a76916bd11d054fd3863dc7c474971922badfbc13d6404f9eddd297041f094"},
]
torchvision = [
- {file = "torchvision-0.8.2-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:86fae370d222f76ad57c57c3bee03f78b8db727743bfb4c1559a3d395159cea8"},
- {file = "torchvision-0.8.2-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:951239b5fcb911dbf78c1385d677f5f48c7a1b12859e3d3ec287562821b17cf2"},
- {file = "torchvision-0.8.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:24db8f4c3d812a032273f68563ad5dbd724f5bfbed523d0c6dce8cede26bb153"},
- {file = "torchvision-0.8.2-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:b068f6bcbe91bdd34dda0a39e8a26392add45a3be82543f6dd523b76484fb56f"},
- {file = "torchvision-0.8.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:afb76a66b9b0693f758a881a2bf333ed97e3c0c3f15a413c4f49d8dd8bd21307"},
- {file = "torchvision-0.8.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:cd8817e9197fc60ebae37162a445db90bbf35591314a5767ad3d1490b5d65b0f"},
- {file = "torchvision-0.8.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1bd58acc3366ec02266aae56a7a752d43ef07de4a6ba420c4f907d0c9168bb8c"},
- {file = "torchvision-0.8.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:976750a49db2e23dc5a1ed0b5c31f7af51ed2702eee410ee09ef985c3a3e48cf"},
+ {file = "torchvision-0.9.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:63052147c776d9f93385410c1d5a791386eb0cb5e1b93c7feac686f8dbe6eb06"},
+ {file = "torchvision-0.9.0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:874714f30822d4c1160071dac004d48ae641bfdccccbb497098c86f6589ec0f1"},
+ {file = "torchvision-0.9.0-cp36-cp36m-win_amd64.whl", hash = "sha256:b9f71f62725776495071b875494af86615f225b1a40902f5df452da5cfde0510"},
+ {file = "torchvision-0.9.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:24b505dbcf3cb8da49d4b1447543c1021b699c84fc3701523101b62ee4adf097"},
+ {file = "torchvision-0.9.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:d90750ae76a0cad8ffb6b509b30412dcd102d27d5f34f7184b289b6687de580e"},
+ {file = "torchvision-0.9.0-cp37-cp37m-win_amd64.whl", hash = "sha256:8791da742c24344646a4ac36adee9327491f7fff7607dffe352402b5bf25ea21"},
+ {file = "torchvision-0.9.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:fa302f6e8fe33a8d5c6649e659655c0427eee662fe22ce69eb56fa402b520c26"},
+ {file = "torchvision-0.9.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:8fce78a59959f4bb4780a78c2277d617e44da7bc270bc449ff403187f6b587fc"},
+ {file = "torchvision-0.9.0-cp38-cp38-win_amd64.whl", hash = "sha256:2252bc63fcccb27785726dd9d0d9a97432657a5d139390bf93cd6bdf227a4401"},
+ {file = "torchvision-0.9.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:421bda7131f3c0eae2260f10174ac3c49e54183b33acb927b4b572f4cd90066d"},
+ {file = "torchvision-0.9.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:d85d405e8cf694c1f85da7f0496ea69dd4f8d8dafbdad1e29bcdc4c621fc5cf0"},
+ {file = "torchvision-0.9.0-cp39-cp39-win_amd64.whl", hash = "sha256:b03275f351feffaf7450d234ffb57cce26ff5e696d01ef5f543de205f18849a9"},
]
tornado = [
{file = "tornado-6.1-cp35-cp35m-macosx_10_9_x86_64.whl", hash = "sha256:d371e811d6b156d82aa5f9a4e08b58debf97c302a35714f6f45e35139c332e32"},
diff --git a/pyproject.toml b/pyproject.toml
index ef75edf..33d539e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,8 +22,8 @@ sphinx_rtd_theme = "^0.4.3"
boltons = "^20.1.0"
h5py = "^3.2.1"
toml = "^0.10.1"
-torch = "^1.7.0"
-torchvision = "^0.8.1"
+torch = "^1.8.0"
+torchvision = "^0.9.0"
loguru = "^0.5.0"
matplotlib = "^3.2.1"
tqdm = "^4.46.1"
@@ -64,12 +64,13 @@ gpustat = "^0.6.0"
redlock-py = "^1.0.8"
wandb = "^0.10.11"
graphviz = "^0.16"
+scipy = "^1.6.1"
[tool.coverage.report]
fail_under = 50
[tool.poetry.scripts]
-download-emnist = "text_recognizer.datasets.util:download_emnist"
+download-emnist = "text_recognizer.datasets.emnist:download_emnist"
download-iam = "text_recognizer.datasets.iam_dataset:main"
create-emnist-support-files = "text_recognizer.tests.support.create_emnist_support_files:create_emnist_support_files"
create-emnist-lines-datasets = "text_recognizer.datasets.emnist_lines_dataset:create_datasets"
@@ -78,7 +79,6 @@ prepare-experiments = "training.prepare_experiments:run_cli"
run-experiment = "training.run_experiment:run_cli"
-
[build-system]
requires = ["poetry>=0.12"]
build-backend = "poetry.masonry.api"
diff --git a/text_recognizer/tests/__init__.py b/tests/__init__.py
index 18ff212..18ff212 100644
--- a/text_recognizer/tests/__init__.py
+++ b/tests/__init__.py
diff --git a/text_recognizer/tests/support/__init__.py b/tests/support/__init__.py
index a265ede..a265ede 100644
--- a/text_recognizer/tests/support/__init__.py
+++ b/tests/support/__init__.py
diff --git a/text_recognizer/tests/support/create_emnist_lines_support_files.py b/tests/support/create_emnist_lines_support_files.py
index 9abe143..9abe143 100644
--- a/text_recognizer/tests/support/create_emnist_lines_support_files.py
+++ b/tests/support/create_emnist_lines_support_files.py
diff --git a/text_recognizer/tests/support/create_emnist_support_files.py b/tests/support/create_emnist_support_files.py
index f9ff030..f9ff030 100644
--- a/text_recognizer/tests/support/create_emnist_support_files.py
+++ b/tests/support/create_emnist_support_files.py
diff --git a/text_recognizer/tests/support/create_iam_lines_support_files.py b/tests/support/create_iam_lines_support_files.py
index 50f9e3d..50f9e3d 100644
--- a/text_recognizer/tests/support/create_iam_lines_support_files.py
+++ b/tests/support/create_iam_lines_support_files.py
diff --git a/text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png b/tests/support/emnist_lines/Knox Ky<eos>.png
index b7d0618..b7d0618 100644
--- a/text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png
+++ b/tests/support/emnist_lines/Knox Ky<eos>.png
Binary files differ
diff --git a/text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png b/tests/support/emnist_lines/ancillary beliefs and<eos>.png
index 14a8cf3..14a8cf3 100644
--- a/text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png
+++ b/tests/support/emnist_lines/ancillary beliefs and<eos>.png
Binary files differ
diff --git a/text_recognizer/tests/support/emnist_lines/they<eos>.png b/tests/support/emnist_lines/they<eos>.png
index 7f05951..7f05951 100644
--- a/text_recognizer/tests/support/emnist_lines/they<eos>.png
+++ b/tests/support/emnist_lines/they<eos>.png
Binary files differ
diff --git a/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png b/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png
index 6eeb642..6eeb642 100644
--- a/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png
+++ b/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png
Binary files differ
diff --git a/text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png b/tests/support/iam_lines/and came into the livingroom, where<eos>.png
index 4974cf8..4974cf8 100644
--- a/text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png
+++ b/tests/support/iam_lines/and came into the livingroom, where<eos>.png
Binary files differ
diff --git a/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png b/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png
index a731245..a731245 100644
--- a/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png
+++ b/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png
Binary files differ
diff --git a/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg b/tests/support/iam_paragraphs/a01-000u.jpg
index d9753b6..d9753b6 100644
--- a/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg
+++ b/tests/support/iam_paragraphs/a01-000u.jpg
Binary files differ
diff --git a/text_recognizer/tests/test_character_predictor.py b/tests/test_character_predictor.py
index 01bda78..01bda78 100644
--- a/text_recognizer/tests/test_character_predictor.py
+++ b/tests/test_character_predictor.py
diff --git a/text_recognizer/tests/test_line_predictor.py b/tests/test_line_predictor.py
index eede4d4..eede4d4 100644
--- a/text_recognizer/tests/test_line_predictor.py
+++ b/tests/test_line_predictor.py
diff --git a/text_recognizer/tests/test_paragraph_text_recognizer.py b/tests/test_paragraph_text_recognizer.py
index 3e280b9..3e280b9 100644
--- a/text_recognizer/tests/test_paragraph_text_recognizer.py
+++ b/tests/test_paragraph_text_recognizer.py
diff --git a/text_recognizer/datasets/__init__.py b/text_recognizer/datasets/__init__.py
index a6c1c59..2727b20 100644
--- a/text_recognizer/datasets/__init__.py
+++ b/text_recognizer/datasets/__init__.py
@@ -1,39 +1 @@
"""Dataset modules."""
-from .emnist_dataset import EmnistDataset
-from .emnist_lines_dataset import (
- construct_image_from_string,
- EmnistLinesDataset,
- get_samples_by_character,
-)
-from .iam_dataset import IamDataset
-from .iam_lines_dataset import IamLinesDataset
-from .iam_paragraphs_dataset import IamParagraphsDataset
-from .iam_preprocessor import load_metadata, Preprocessor
-from .transforms import AddTokens, Transpose
-from .util import (
- _download_raw_dataset,
- compute_sha256,
- DATA_DIRNAME,
- download_url,
- EmnistMapper,
- ESSENTIALS_FILENAME,
-)
-
-__all__ = [
- "_download_raw_dataset",
- "AddTokens",
- "compute_sha256",
- "construct_image_from_string",
- "DATA_DIRNAME",
- "download_url",
- "EmnistDataset",
- "EmnistMapper",
- "EmnistLinesDataset",
- "get_samples_by_character",
- "load_metadata",
- "IamDataset",
- "IamLinesDataset",
- "IamParagraphsDataset",
- "Preprocessor",
- "Transpose",
-]
diff --git a/text_recognizer/datasets/base_data_module.py b/text_recognizer/datasets/base_data_module.py
index 09a0a43..830b39b 100644
--- a/text_recognizer/datasets/base_data_module.py
+++ b/text_recognizer/datasets/base_data_module.py
@@ -16,7 +16,7 @@ def load_and_print_info(data_module_class: type) -> None:
class BaseDataModule(pl.LightningDataModule):
"""Base PyTorch Lightning DataModule."""
-
+
def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None:
super().__init__()
self.batch_size = batch_size
@@ -34,13 +34,17 @@ class BaseDataModule(pl.LightningDataModule):
def config(self) -> Dict:
"""Return important settings of the dataset."""
- return {"input_dim": self.dims, "output_dims": self.output_dims, "mapping": self.mapping}
+ return {
+ "input_dim": self.dims,
+ "output_dims": self.output_dims,
+ "mapping": self.mapping,
+ }
def prepare_data(self) -> None:
"""Prepare data for training."""
pass
- def setup(self, stage: Any = None) -> None:
+ def setup(self, stage: str = None) -> None:
"""Split into train, val, test, and set dims.
Should assign `torch Dataset` objects to self.data_train, self.data_val, and
@@ -54,16 +58,32 @@ class BaseDataModule(pl.LightningDataModule):
self.data_val = None
self.data_test = None
-
def train_dataloader(self) -> DataLoader:
"""Retun DataLoader for train data."""
- return DataLoader(self.data_train, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)
+ return DataLoader(
+ self.data_train,
+ shuffle=True,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ pin_memory=True,
+ )
def val_dataloader(self) -> DataLoader:
"""Return DataLoader for val data."""
- return DataLoader(self.data_val, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)
+ return DataLoader(
+ self.data_val,
+ shuffle=False,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ pin_memory=True,
+ )
def test_dataloader(self) -> DataLoader:
"""Return DataLoader for val data."""
- return DataLoader(self.data_test, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)
-
+ return DataLoader(
+ self.data_test,
+ shuffle=False,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ pin_memory=True,
+ )
diff --git a/text_recognizer/datasets/base_dataset.py b/text_recognizer/datasets/base_dataset.py
index 7322d7f..a004b8d 100644
--- a/text_recognizer/datasets/base_dataset.py
+++ b/text_recognizer/datasets/base_dataset.py
@@ -17,12 +17,14 @@ class BaseDataset(Dataset):
target_transform (Callable): Fucntion that takes a target and applies
target transforms.
"""
- def __init__(self,
- data: Union[Sequence, Tensor],
- targets: Union[Sequence, Tensor],
- transform: Callable = None,
- target_transform: Callable = None,
- ) -> None:
+
+ def __init__(
+ self,
+ data: Union[Sequence, Tensor],
+ targets: Union[Sequence, Tensor],
+ transform: Callable = None,
+ target_transform: Callable = None,
+ ) -> None:
if len(data) != len(targets):
raise ValueError("Data and targets must be of equal length.")
self.data = data
@@ -30,11 +32,10 @@ class BaseDataset(Dataset):
self.transform = transform
self.target_transform = target_transform
-
def __len__(self) -> int:
"""Return the length of the dataset."""
return len(self.data)
-
+
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""Return a datum and its target, after processing by transforms.
@@ -56,7 +57,9 @@ class BaseDataset(Dataset):
return datum, target
-def convert_strings_to_labels(strings: Sequence[str], mapping: Dict[str, int], length: int) -> Tensor:
+def convert_strings_to_labels(
+ strings: Sequence[str], mapping: Dict[str, int], length: int
+) -> Tensor:
"""
Convert a sequence of N strings to (N, length) ndarray, with each string wrapped with <S> and </S> tokens,
and padded wiht the <P> token.
diff --git a/text_recognizer/datasets/dataset.py b/text_recognizer/datasets/dataset.py
deleted file mode 100644
index e794605..0000000
--- a/text_recognizer/datasets/dataset.py
+++ /dev/null
@@ -1,152 +0,0 @@
-"""Abstract dataset class."""
-from typing import Callable, Dict, List, Optional, Tuple, Union
-
-import torch
-from torch import Tensor
-from torch.utils import data
-from torchvision.transforms import ToTensor
-
-import text_recognizer.datasets.transforms as transforms
-from text_recognizer.datasets.util import EmnistMapper
-
-
-class Dataset(data.Dataset):
- """Abstract class for with common methods for all datasets."""
-
- def __init__(
- self,
- train: bool,
- subsample_fraction: float = None,
- transform: Optional[List[Dict]] = None,
- target_transform: Optional[List[Dict]] = None,
- init_token: Optional[str] = None,
- pad_token: Optional[str] = None,
- eos_token: Optional[str] = None,
- lower: bool = False,
- ) -> None:
- """Initialization of Dataset class.
-
- Args:
- train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False.
- subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None.
- transform (Optional[List[Dict]]): List of Transform types and args for input data. Defaults to None.
- target_transform (Optional[List[Dict]]): List of Transform types and args for output data. Defaults to None.
- init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
- pad_token (Optional[str]): String representing the pad token. Defaults to None.
- eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
- lower (bool): Only use lower case letters. Defaults to False.
-
- Raises:
- ValueError: If subsample_fraction is not None and outside the range (0, 1).
-
- """
- self.train = train
- self.split = "train" if self.train else "test"
-
- if subsample_fraction is not None:
- if not 0.0 < subsample_fraction < 1.0:
- raise ValueError("The subsample fraction must be in (0, 1).")
- self.subsample_fraction = subsample_fraction
-
- self._mapper = EmnistMapper(
- init_token=init_token, eos_token=eos_token, pad_token=pad_token, lower=lower
- )
- self._input_shape = self._mapper.input_shape
- self._output_shape = self._mapper._num_classes
- self.num_classes = self.mapper.num_classes
-
- # Set transforms.
- self.transform = self._configure_transform(transform)
- self.target_transform = self._configure_target_transform(target_transform)
-
- self._data = None
- self._targets = None
-
- def _configure_transform(self, transform: List[Dict]) -> transforms.Compose:
- transform_list = []
- if transform is not None:
- for t in transform:
- t_type = t["type"]
- t_args = t["args"] or {}
- transform_list.append(getattr(transforms, t_type)(**t_args))
- else:
- transform_list.append(ToTensor())
- return transforms.Compose(transform_list)
-
- def _configure_target_transform(
- self, target_transform: List[Dict]
- ) -> transforms.Compose:
- target_transform_list = [torch.tensor]
- if target_transform is not None:
- for t in target_transform:
- t_type = t["type"]
- t_args = t["args"] or {}
- target_transform_list.append(getattr(transforms, t_type)(**t_args))
- return transforms.Compose(target_transform_list)
-
- @property
- def data(self) -> Tensor:
- """The input data."""
- return self._data
-
- @property
- def targets(self) -> Tensor:
- """The target data."""
- return self._targets
-
- @property
- def input_shape(self) -> Tuple:
- """Input shape of the data."""
- return self._input_shape
-
- @property
- def output_shape(self) -> Tuple:
- """Output shape of the data."""
- return self._output_shape
-
- @property
- def mapper(self) -> EmnistMapper:
- """Returns the EmnistMapper."""
- return self._mapper
-
- @property
- def mapping(self) -> Dict:
- """Return EMNIST mapping from index to character."""
- return self._mapper.mapping
-
- @property
- def inverse_mapping(self) -> Dict:
- """Returns the inverse mapping from character to index."""
- return self.mapper.inverse_mapping
-
- def _subsample(self) -> None:
- """Only this fraction of the data will be loaded."""
- if self.subsample_fraction is None:
- return
- num_subsample = int(self.data.shape[0] * self.subsample_fraction)
- self._data = self.data[:num_subsample]
- self._targets = self.targets[:num_subsample]
-
- def __len__(self) -> int:
- """Returns the length of the dataset."""
- return len(self.data)
-
- def load_or_generate_data(self) -> None:
- """Load or generate dataset data."""
- raise NotImplementedError
-
- def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
- """Fetches samples from the dataset.
-
- Args:
- index (Union[int, torch.Tensor]): The indices of the samples to fetch.
-
- Raises:
- NotImplementedError: If the method is not implemented in child class.
-
- """
- raise NotImplementedError
-
- def __repr__(self) -> str:
- """Returns information about the dataset."""
- raise NotImplementedError
diff --git a/text_recognizer/datasets/download_utils.py b/text_recognizer/datasets/download_utils.py
index 7a2cab8..e3dc68c 100644
--- a/text_recognizer/datasets/download_utils.py
+++ b/text_recognizer/datasets/download_utils.py
@@ -63,11 +63,11 @@ def download_dataset(metadata: Dict, dl_dir: Path) -> Optional[Path]:
if filename.exists():
return
logger.info(f"Downloading raw dataset from {metadata['url']} to {filename}...")
- _download_url(metadata["url"], filename)
+ _download_url(metadata["url"], filename)
logger.info("Computing the SHA-256...")
sha256 = _compute_sha256(filename)
if sha256 != metadata["sha256"]:
raise ValueError(
- "Downloaded data file SHA-256 does not match that listed in metadata document."
- )
+ "Downloaded data file SHA-256 does not match that listed in metadata document."
+ )
return filename
diff --git a/text_recognizer/datasets/emnist.py b/text_recognizer/datasets/emnist.py
index e99dbfd..7c208c4 100644
--- a/text_recognizer/datasets/emnist.py
+++ b/text_recognizer/datasets/emnist.py
@@ -15,20 +15,23 @@ from torch.utils.data import random_split
from torchvision import transforms
from text_recognizer.datasets.base_dataset import BaseDataset
-from text_recognizer.datasets.base_data_module import BaseDataModule, load_print_info
+from text_recognizer.datasets.base_data_module import (
+ BaseDataModule,
+ load_and_print_info,
+)
from text_recognizer.datasets.download_utils import download_dataset
SEED = 4711
NUM_SPECIAL_TOKENS = 4
-SAMPLE_TO_BALANCE = True
+SAMPLE_TO_BALANCE = True
RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "emnist"
METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "emnist"
-PROCESSED_DATA_DIRNAME = BaseDataset.data_dirname() / "processed" / "emnist"
+PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist"
PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5"
-ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnsit_essentials.json"
+ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json"
class EMNIST(BaseDataModule):
@@ -41,7 +44,9 @@ class EMNIST(BaseDataModule):
EMNIST ByClass: 814,255 characters. 62 unbalanced classes.
"""
- def __init__(self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8) -> None:
+ def __init__(
+ self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8
+ ) -> None:
super().__init__(batch_size, num_workers)
if not ESSENTIALS_FILENAME.exists():
_download_and_process_emnist()
@@ -64,20 +69,21 @@ class EMNIST(BaseDataModule):
def setup(self, stage: str = None) -> None:
if stage == "fit" or stage is None:
with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
- data = f["x_train"][:]
- targets = f["y_train"][:]
-
- dataset_train = BaseDataset(data, targets, transform=self.transform)
+ self.x_train = f["x_train"][:]
+ self.y_train = f["y_train"][:]
+
+ dataset_train = BaseDataset(self.x_train, self.y_train, transform=self.transform)
train_size = int(self.train_fraction * len(dataset_train))
val_size = len(dataset_train) - train_size
- self.data_train, self.data_val = random_split(dataset_train, [train_size, val_size], generator=torch.Generator())
+ self.data_train, self.data_val = random_split(
+ dataset_train, [train_size, val_size], generator=torch.Generator()
+ )
if stage == "test" or stage is None:
with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
- data = f["x_test"][:]
- targets = f["y_test"][:]
- self.data_test = BaseDataset(data, targets, transform=self.transform)
-
+ self.x_test = f["x_test"][:]
+ self.y_test = f["y_test"][:]
+ self.data_test = BaseDataset(self.x_test, self.y_test, transform=self.transform)
def __repr__(self) -> str:
basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.dims}\n"
@@ -111,9 +117,15 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
logger.info("Loading training data from .mat file")
data = loadmat("matlab/emnist-byclass.mat")
- x_train = data["dataset"]["train"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2)
+ x_train = (
+ data["dataset"]["train"][0, 0]["images"][0, 0]
+ .reshape(-1, 28, 28)
+ .swapaxes(1, 2)
+ )
y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
- x_test = data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2)
+ x_test = (
+ data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2)
+ )
y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
if SAMPLE_TO_BALANCE:
@@ -121,7 +133,6 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
x_train, y_train = _sample_to_balance(x_train, y_train)
x_test, y_test = _sample_to_balance(x_test, y_test)
-
logger.info("Saving to HDF5 in a compressed format...")
PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
with h5py.File(PROCESSED_DATA_FILENAME, "w") as f:
@@ -154,7 +165,7 @@ def _sample_to_balance(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.nda
all_sampled_indices.append(sampled_indices)
indices = np.concatenate(all_sampled_indices)
x_sampled = x[indices]
- y_sampled= y[indices]
+ y_sampled = y[indices]
return x_sampled, y_sampled
@@ -162,24 +173,24 @@ def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]:
"""Augment the mapping with extra symbols."""
# Extra characters from the IAM dataset.
iam_characters = [
- " ",
- "!",
- '"',
- "#",
- "&",
- "'",
- "(",
- ")",
- "*",
- "+",
- ",",
- "-",
- ".",
- "/",
- ":",
- ";",
- "?",
- ]
+ " ",
+ "!",
+ '"',
+ "#",
+ "&",
+ "'",
+ "(",
+ ")",
+ "*",
+ "+",
+ ",",
+ "-",
+ ".",
+ "/",
+ ":",
+ ";",
+ "?",
+ ]
# Also add special tokens for:
# - CTC blank token at index 0
@@ -190,5 +201,6 @@ def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]:
return ["<b>", "<s>", "</s>", "<p>", *characters, *iam_characters]
-if __name__ == "__main__":
- load_print_info(EMNIST)
+def download_emnist() -> None:
+ """Download dataset from internet, if it does not exists, and displays info."""
+ load_and_print_info(EMNIST)
diff --git a/text_recognizer/datasets/emnist_essentials.json b/text_recognizer/datasets/emnist_essentials.json
new file mode 100644
index 0000000..100b36a
--- /dev/null
+++ b/text_recognizer/datasets/emnist_essentials.json
@@ -0,0 +1 @@
+{"characters": ["<b>", "<s>", "</s>", "<p>", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]} \ No newline at end of file
diff --git a/text_recognizer/datasets/emnist_lines.py b/text_recognizer/datasets/emnist_lines.py
new file mode 100644
index 0000000..ae23feb
--- /dev/null
+++ b/text_recognizer/datasets/emnist_lines.py
@@ -0,0 +1,184 @@
+"""Dataset of generated text from EMNIST characters."""
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, Sequence
+
+import h5py
+from loguru import logger
+import numpy as np
+import torch
+from torchvision import transforms
+
+from text_recognizer.datasets.base_dataset import BaseDataset
+from text_recognizer.datasets.base_data_module import BaseDataModule
+from text_recognizer.datasets.emnist import EMNIST
+from text_recognizer.datasets.sentence_generator import SentenceGenerator
+
+
+DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist_lines"
+ESSENTIALS_FILENAME = (
+ Path(__file__).parents[0].resolve() / "emnist_lines_essentials.json"
+)
+
+SEED = 4711
+IMAGE_HEIGHT = 56
+IMAGE_WIDTH = 1024
+IMAGE_X_PADDING = 28
+MAX_OUTPUT_LENGTH = 89 # Same as IAMLines
+
+
+class EMNISTLines(BaseDataModule):
+ """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST,"""
+
+ def __init__(
+ self,
+ augment: bool = True,
+ batch_size: int = 128,
+ num_workers: int = 0,
+ max_length: int = 32,
+ min_overlap: float = 0.0,
+ max_overlap: float = 0.33,
+ num_train: int = 10_000,
+ num_val: int = 2_000,
+ num_test: int = 2_000,
+ ) -> None:
+ super().__init__(batch_size, num_workers)
+
+ self.augment = augment
+ self.max_length = max_length
+ self.min_overlap = min_overlap
+ self.max_overlap = max_overlap
+ self.num_train = num_train
+ self.num_val = num_val
+ self.num_test = num_test
+
+ self.emnist = EMNIST()
+ self.mapping = self.emnist.mapping
+ max_width = int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap)) + IMAGE_X_PADDING
+
+ if max_width <= IMAGE_WIDTH:
+ raise ValueError("max_width greater than IMAGE_WIDTH")
+
+ self.dims = (
+ self.emnist.dims[0],
+ self.emnist.dims[1],
+ self.emnist.dims[2] * self.max_length,
+ )
+
+ if self.max_length <= MAX_OUTPUT_LENGTH:
+ raise ValueError("max_length greater than MAX_OUTPUT_LENGTH")
+
+ self.output_dims = (MAX_OUTPUT_LENGTH, 1)
+ self.data_train = None
+ self.data_val = None
+ self.data_test = None
+
+ @property
+ def data_filename(self) -> Path:
+ """Return name of dataset."""
+ return (
+ DATA_DIRNAME
+ / f"ml_{self.max_length}_o{self.min_overlap:f}_{self.max_overlap:f}_ntr{self.num_train}_ntv{self.num_val}_nte{self.num_test}_{self.with_start_end_tokens}.h5"
+ )
+
+ def prepare_data(self) -> None:
+ if self.data_filename.exists():
+ return
+ np.random.seed(SEED)
+ self._generate_data("train")
+ self._generate_data("val")
+ self._generate_data("test")
+
+ def setup(self, stage: str = None) -> None:
+ logger.info("EMNISTLinesDataset loading data from HDF5...")
+ if stage == "fit" or stage is None:
+ with h5py.File(self.data_filename, "r") as f:
+ x_train = f["x_train"][:]
+ y_train = torch.LongTensor(f["y_train"][:])
+ x_val = f["x_val"][:]
+ y_val = torch.LongTensor(f["y_val"][:])
+
+ self.data_train = BaseDataset(x_train, y_train, transform=_get_transform(augment=self.augment))
+ self.data_val = BaseDataset(x_val, y_val, transform=_get_transform(augment=self.augment))
+
+ if stage == "test" or stage is None:
+ with h5py.File(self.data_filename, "r") as f:
+ x_test = f["x_test"][:]
+ y_test = torch.LongTensor(f["y_test"][:])
+
+ self.data_train = BaseDataset(x_test, y_test, transform=_get_transform(augment=False))
+
+ def __repr__(self) -> str:
+ """Return str about dataset."""
+ basic = (
+ "EMNISTLines2 Dataset\n" # pylint: disable=no-member
+ f"Min overlap: {self.min_overlap}\n"
+ f"Max overlap: {self.max_overlap}\n"
+ f"Num classes: {len(self.mapping)}\n"
+ f"Dims: {self.dims}\n"
+ f"Output dims: {self.output_dims}\n"
+ )
+
+ if not any([self.data_train, self.data_val, self.data_test]):
+ return basic
+
+ x, y = next(iter(self.train_dataloader()))
+ data = (
+ f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n"
+ f"Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
+ f"Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n"
+ )
+ return basic + data
+
+ def _generate_data(self, split: str) -> None:
+ logger.info(f"EMNISTLines generating data for {split}...")
+ sentence_generator = SentenceGenerator(self.max_length - 2) # Subtract by 2 because start/end token
+
+ emnist = self.emnist
+ emnist.prepare_data()
+ emnist.setup()
+
+ if split == "train":
+ samples_by_char = _get_samples_by_char(emnist.x_train, emnist.y_train, emnist.mapping)
+ num = self.num_train
+ elif split == "val":
+ samples_by_char = _get_samples_by_char(emnist.x_train, emnist.y_train, emnist.mapping)
+ num = self.num_val
+ elif split == "test":
+ samples_by_char = _get_samples_by_char(emnist.x_test, emnist.y_test, emnist.mapping)
+ num = self.num_test
+
+ DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
+ with h5py.File(self.data_filename, "w") as f:
+ x, y = _create_dataset_of_images(
+ num, samples_by_char, sentence_generator, self.min_overlap, self.max_overlap, self.dims
+ )
+ y = _convert_strings_to_labels(
+ y,
+ emnist.inverse_mapping,
+ length=MAX_OUTPUT_LENGTH
+ )
+ f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf")
+ f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf")
+
+def _get_samples_by_char(samples: np.ndarray, labels: np.ndarray, mapping: Dict) -> defaultdict:
+ samples_by_char = defaultdict(list)
+ for sample, label in zip(samples, labels):
+ samples_by_char[mapping[label]].append(sample)
+ return samples_by_char
+
+
+def _construct_image_from_string():
+ pass
+
+
+def _select_letter_samples_for_string(string: str, samples_by_char: defaultdict):
+ pass
+
+
+def _create_dataset_of_images(num_samples: int, samples_by_char: defaultdict, sentence_generator: SentenceGenerator, min_overlap: float, max_overlap: float, dims: Tuple) -> Tuple[torch.Tensor, torch.Tensor]:
+ images = torch.zeros((num_samples, IMAGE_HEIGHT, dims[2]))
+ labels = []
+ for n in range(num_samples):
+ label = sentence_generator.generate()
+ crop = _construct_image_from_string()