diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-21 22:33:58 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-21 22:33:58 +0100 |
commit | e3741de333a3a43a7968241b6eccaaac66dd7b20 (patch) | |
tree | 7c50aee4ca61f77e95f1b038030292c64bbb86c2 | |
parent | aac452a2dc008338cb543549652da293c14b6b4e (diff) |
Working on EMNIST Lines dataset
-rw-r--r-- | data/EMNIST/raw/metadata.toml | 3 | ||||
-rw-r--r-- | noxfile.py | 9 | ||||
-rw-r--r-- | poetry.lock | 107 | ||||
-rw-r--r-- | pyproject.toml | 8 | ||||
-rw-r--r-- | tests/__init__.py (renamed from text_recognizer/tests/__init__.py) | 0 | ||||
-rw-r--r-- | tests/support/__init__.py (renamed from text_recognizer/tests/support/__init__.py) | 0 | ||||
-rw-r--r-- | tests/support/create_emnist_lines_support_files.py (renamed from text_recognizer/tests/support/create_emnist_lines_support_files.py) | 0 | ||||
-rw-r--r-- | tests/support/create_emnist_support_files.py (renamed from text_recognizer/tests/support/create_emnist_support_files.py) | 0 | ||||
-rw-r--r-- | tests/support/create_iam_lines_support_files.py (renamed from text_recognizer/tests/support/create_iam_lines_support_files.py) | 0 | ||||
-rw-r--r-- | tests/support/emnist_lines/Knox Ky<eos>.png (renamed from text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png) | bin | 2301 -> 2301 bytes | |||
-rw-r--r-- | tests/support/emnist_lines/ancillary beliefs and<eos>.png (renamed from text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png) | bin | 5424 -> 5424 bytes | |||
-rw-r--r-- | tests/support/emnist_lines/they<eos>.png (renamed from text_recognizer/tests/support/emnist_lines/they<eos>.png) | bin | 1391 -> 1391 bytes | |||
-rw-r--r-- | tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png (renamed from text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png) | bin | 5170 -> 5170 bytes | |||
-rw-r--r-- | tests/support/iam_lines/and came into the livingroom, where<eos>.png (renamed from text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png) | bin | 3617 -> 3617 bytes | |||
-rw-r--r-- | tests/support/iam_lines/his entrance. He came, almost falling<eos>.png (renamed from text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png) | bin | 3923 -> 3923 bytes | |||
-rw-r--r-- | tests/support/iam_paragraphs/a01-000u.jpg (renamed from text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg) | bin | 14890 -> 14890 bytes | |||
-rw-r--r-- | tests/test_character_predictor.py (renamed from text_recognizer/tests/test_character_predictor.py) | 0 | ||||
-rw-r--r-- | tests/test_line_predictor.py (renamed from text_recognizer/tests/test_line_predictor.py) | 0 | ||||
-rw-r--r-- | tests/test_paragraph_text_recognizer.py (renamed from text_recognizer/tests/test_paragraph_text_recognizer.py) | 0 | ||||
-rw-r--r-- | text_recognizer/datasets/__init__.py | 38 | ||||
-rw-r--r-- | text_recognizer/datasets/base_data_module.py | 36 | ||||
-rw-r--r-- | text_recognizer/datasets/base_dataset.py | 21 | ||||
-rw-r--r-- | text_recognizer/datasets/dataset.py | 152 | ||||
-rw-r--r-- | text_recognizer/datasets/download_utils.py | 6 | ||||
-rw-r--r-- | text_recognizer/datasets/emnist.py | 88 | ||||
-rw-r--r-- | text_recognizer/datasets/emnist_essentials.json | 1 | ||||
-rw-r--r-- | text_recognizer/datasets/emnist_lines.py | 184 |
27 files changed, 345 insertions, 308 deletions
diff --git a/data/EMNIST/raw/metadata.toml b/data/EMNIST/raw/metadata.toml deleted file mode 100644 index 10304ce..0000000 --- a/data/EMNIST/raw/metadata.toml +++ /dev/null @@ -1,3 +0,0 @@ -filename = 'gzip.zip' -md5 = '58c8d27c78d21e728a6bc7b3cc06412e' -url = 'http://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip' @@ -8,7 +8,14 @@ from nox.sessions import Session package = "text-recognizer" nox.options.sessions = "lint", "mypy", "pytype", "safety", "tests" -locations = "src", "tests", "noxfile.py", "docs/conf.py", "src/text_recognizer/tests" +locations = ( + "text_recognizer", + "training", + "tasks", + "tests", + "noxfile.py", + "docs/conf.py", +) def install_with_constraints(session: Session, *args: str, **kwargs: Any) -> None: diff --git a/poetry.lock b/poetry.lock index a389e98..094043d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1730,14 +1730,14 @@ alldeps = ["numpy (>=1.13.3)", "scipy (>=0.19.1)"] [[package]] name = "scipy" -version = "1.5.4" +version = "1.6.1" description = "SciPy: Scientific Library for Python" category = "main" optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" [package.dependencies] -numpy = ">=1.14.5" +numpy = ">=1.16.5" [[package]] name = "send2trash" @@ -2033,7 +2033,7 @@ python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" [[package]] name = "torch" -version = "1.7.1" +version = "1.8.0" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" category = "main" optional = false @@ -2053,7 +2053,7 @@ python-versions = ">=3.5" [[package]] name = "torchvision" -version = "0.8.2" +version = "0.9.0" description = "image and video datasets and models for torch deep learning" category = "main" optional = false @@ -2062,7 +2062,7 @@ python-versions = "*" [package.dependencies] numpy = "*" pillow = ">=4.1.1" -torch = "1.7.1" +torch = "1.8.0" [package.extras] scipy = ["scipy"] @@ -2279,7 +2279,7 @@ multidict = ">=4.0" [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "45c6282e27a0231e3dc3c951a9540f3cd2f6d6eb3d3eda1eab84168014d8062a" +content-hash = "82834b5e12f38b06a306f1b8c05ed6ee33d671c747de40615d2bd36f3f138279" [metadata.files] absl-py = [ @@ -3385,31 +3385,25 @@ scikit-learn = [ {file = "scikit_learn-0.23.2-cp38-cp38-win_amd64.whl", hash = "sha256:1b8a391de95f6285a2f9adffb7db0892718950954b7149a70c783dc848f104ea"}, ] scipy = [ - {file = "scipy-1.5.4-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:4f12d13ffbc16e988fa40809cbbd7a8b45bc05ff6ea0ba8e3e41f6f4db3a9e47"}, - {file = "scipy-1.5.4-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:a254b98dbcc744c723a838c03b74a8a34c0558c9ac5c86d5561703362231107d"}, - {file = "scipy-1.5.4-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:368c0f69f93186309e1b4beb8e26d51dd6f5010b79264c0f1e9ca00cd92ea8c9"}, - {file = "scipy-1.5.4-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:4598cf03136067000855d6b44d7a1f4f46994164bcd450fb2c3d481afc25dd06"}, - {file = "scipy-1.5.4-cp36-cp36m-win32.whl", hash = "sha256:e98d49a5717369d8241d6cf33ecb0ca72deee392414118198a8e5b4c35c56340"}, - {file = "scipy-1.5.4-cp36-cp36m-win_amd64.whl", hash = "sha256:65923bc3809524e46fb7eb4d6346552cbb6a1ffc41be748535aa502a2e3d3389"}, - {file = "scipy-1.5.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:9ad4fcddcbf5dc67619379782e6aeef41218a79e17979aaed01ed099876c0e62"}, - {file = "scipy-1.5.4-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:f87b39f4d69cf7d7529d7b1098cb712033b17ea7714aed831b95628f483fd012"}, - {file = "scipy-1.5.4-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:25b241034215247481f53355e05f9e25462682b13bd9191359075682adcd9554"}, - {file = "scipy-1.5.4-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:fa789583fc94a7689b45834453fec095245c7e69c58561dc159b5d5277057e4c"}, - {file = "scipy-1.5.4-cp37-cp37m-win32.whl", hash = "sha256:d6d25c41a009e3c6b7e757338948d0076ee1dd1770d1c09ec131f11946883c54"}, - {file = "scipy-1.5.4-cp37-cp37m-win_amd64.whl", hash = "sha256:2c872de0c69ed20fb1a9b9cf6f77298b04a26f0b8720a5457be08be254366c6e"}, - {file = "scipy-1.5.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e360cb2299028d0b0d0f65a5c5e51fc16a335f1603aa2357c25766c8dab56938"}, - {file = "scipy-1.5.4-cp38-cp38-manylinux1_i686.whl", hash = "sha256:3397c129b479846d7eaa18f999369a24322d008fac0782e7828fa567358c36ce"}, - {file = "scipy-1.5.4-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:168c45c0c32e23f613db7c9e4e780bc61982d71dcd406ead746c7c7c2f2004ce"}, - {file = "scipy-1.5.4-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:213bc59191da2f479984ad4ec39406bf949a99aba70e9237b916ce7547b6ef42"}, - {file = "scipy-1.5.4-cp38-cp38-win32.whl", hash = "sha256:634568a3018bc16a83cda28d4f7aed0d803dd5618facb36e977e53b2df868443"}, - {file = "scipy-1.5.4-cp38-cp38-win_amd64.whl", hash = "sha256:b03c4338d6d3d299e8ca494194c0ae4f611548da59e3c038813f1a43976cb437"}, - {file = "scipy-1.5.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3d5db5d815370c28d938cf9b0809dade4acf7aba57eaf7ef733bfedc9b2474c4"}, - {file = "scipy-1.5.4-cp39-cp39-manylinux1_i686.whl", hash = "sha256:6b0ceb23560f46dd236a8ad4378fc40bad1783e997604ba845e131d6c680963e"}, - {file = "scipy-1.5.4-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:ed572470af2438b526ea574ff8f05e7f39b44ac37f712105e57fc4d53a6fb660"}, - {file = "scipy-1.5.4-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:8c8d6ca19c8497344b810b0b0344f8375af5f6bb9c98bd42e33f747417ab3f57"}, - {file = "scipy-1.5.4-cp39-cp39-win32.whl", hash = "sha256:d84cadd7d7998433334c99fa55bcba0d8b4aeff0edb123b2a1dfcface538e474"}, - {file = "scipy-1.5.4-cp39-cp39-win_amd64.whl", hash = "sha256:cc1f78ebc982cd0602c9a7615d878396bec94908db67d4ecddca864d049112f2"}, - {file = "scipy-1.5.4.tar.gz", hash = "sha256:4a453d5e5689de62e5d38edf40af3f17560bfd63c9c5bd228c18c1f99afa155b"}, + {file = "scipy-1.6.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:a15a1f3fc0abff33e792d6049161b7795909b40b97c6cc2934ed54384017ab76"}, + {file = "scipy-1.6.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:e79570979ccdc3d165456dd62041d9556fb9733b86b4b6d818af7a0afc15f092"}, + {file = "scipy-1.6.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:a423533c55fec61456dedee7b6ee7dce0bb6bfa395424ea374d25afa262be261"}, + {file = "scipy-1.6.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:33d6b7df40d197bdd3049d64e8e680227151673465e5d85723b3b8f6b15a6ced"}, + {file = "scipy-1.6.1-cp37-cp37m-win32.whl", hash = "sha256:6725e3fbb47da428794f243864f2297462e9ee448297c93ed1dcbc44335feb78"}, + {file = "scipy-1.6.1-cp37-cp37m-win_amd64.whl", hash = "sha256:5fa9c6530b1661f1370bcd332a1e62ca7881785cc0f80c0d559b636567fab63c"}, + {file = "scipy-1.6.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bd50daf727f7c195e26f27467c85ce653d41df4358a25b32434a50d8870fc519"}, + {file = "scipy-1.6.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:f46dd15335e8a320b0fb4685f58b7471702234cba8bb3442b69a3e1dc329c345"}, + {file = "scipy-1.6.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:0e5b0ccf63155d90da576edd2768b66fb276446c371b73841e3503be1d63fb5d"}, + {file = "scipy-1.6.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:2481efbb3740977e3c831edfd0bd9867be26387cacf24eb5e366a6a374d3d00d"}, + {file = "scipy-1.6.1-cp38-cp38-win32.whl", hash = "sha256:68cb4c424112cd4be886b4d979c5497fba190714085f46b8ae67a5e4416c32b4"}, + {file = "scipy-1.6.1-cp38-cp38-win_amd64.whl", hash = "sha256:5f331eeed0297232d2e6eea51b54e8278ed8bb10b099f69c44e2558c090d06bf"}, + {file = "scipy-1.6.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0c8a51d33556bf70367452d4d601d1742c0e806cd0194785914daf19775f0e67"}, + {file = "scipy-1.6.1-cp39-cp39-manylinux1_i686.whl", hash = "sha256:83bf7c16245c15bc58ee76c5418e46ea1811edcc2e2b03041b804e46084ab627"}, + {file = "scipy-1.6.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:794e768cc5f779736593046c9714e0f3a5940bc6dcc1dba885ad64cbfb28e9f0"}, + {file = "scipy-1.6.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:5da5471aed911fe7e52b86bf9ea32fb55ae93e2f0fac66c32e58897cfb02fa07"}, + {file = "scipy-1.6.1-cp39-cp39-win32.whl", hash = "sha256:8e403a337749ed40af60e537cc4d4c03febddcc56cd26e774c9b1b600a70d3e4"}, + {file = "scipy-1.6.1-cp39-cp39-win_amd64.whl", hash = "sha256:a5193a098ae9f29af283dcf0041f762601faf2e595c0db1da929875b7570353f"}, + {file = "scipy-1.6.1.tar.gz", hash = "sha256:c4fceb864890b6168e79b0e714c585dbe2fd4222768ee90bc1aa0f8218691b11"}, ] send2trash = [ {file = "Send2Trash-1.5.0-py3-none-any.whl", hash = "sha256:f1691922577b6fa12821234aeb57599d887c4900b9ca537948d2dac34aea888b"}, @@ -3544,32 +3538,41 @@ toml = [ {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, ] torch = [ - {file = "torch-1.7.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:422e64e98d0e100c360993819d0307e5d56e9517b26135808ad68984d577d75a"}, - {file = "torch-1.7.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f0aaf657145533824b15f2fd8fde8f8c67fe6c6281088ef588091f03fad90243"}, - {file = "torch-1.7.1-cp36-none-macosx_10_9_x86_64.whl", hash = "sha256:af464a6f4314a875035e0c4c2b07517599704b214634f4ed3ad2e748c5ef291f"}, - {file = "torch-1.7.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:5d76c255a41484c1d41a9ff570b9c9f36cb85df9428aa15a58ae16ac7cfc2ea6"}, - {file = "torch-1.7.1-cp37-cp37m-win_amd64.whl", hash = "sha256:d241c3f1c4d563e4ba86f84769c23e12606db167ee6f674eedff6d02901462e3"}, - {file = "torch-1.7.1-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:de84b4166e3f7335eb868b51d3bbd909ec33828af27290b4171bce832a55be3c"}, - {file = "torch-1.7.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:dd2fc6880c95e836960d86efbbc7f63d3287f2e1893c51d31f96dbfe02f0d73e"}, - {file = "torch-1.7.1-cp38-cp38-win_amd64.whl", hash = "sha256:e000b94be3aa58ad7f61e7d07cf379ea9366cf6c6874e68bd58ad0bdc537b3a7"}, - {file = "torch-1.7.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:2e49cac969976be63117004ee00d0a3e3dd4ea662ad77383f671b8992825de1a"}, - {file = "torch-1.7.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:a3793dcceb12b1e2281290cca1277c5ce86ddfd5bf044f654285a4d69057aea7"}, - {file = "torch-1.7.1-cp39-cp39-win_amd64.whl", hash = "sha256:6652a767a0572ae0feb74ad128758e507afd3b8396b6e7f147e438ba8d4c6f63"}, - {file = "torch-1.7.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:38d67f4fb189a92a977b2c0a38e4f6dd413e0bf55aa6d40004696df7e40a71ff"}, + {file = "torch-1.8.0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:78b84115fd03f4587382a38b0da98cdd1827117806c80ebf97843a64213816cc"}, + {file = "torch-1.8.0-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:86f13f324fd87870bd0d37864f4f5814dc27f9e7ed9ea222f1cc7d7dc01a8ffe"}, + {file = "torch-1.8.0-cp36-cp36m-win_amd64.whl", hash = "sha256:394a99d777e487e773e0172cb0a0bce5b411e3090d89844e8dd55618be9bc970"}, + {file = "torch-1.8.0-cp36-none-macosx_10_9_x86_64.whl", hash = "sha256:229a8dc38059ef6c7171f3f4f49c51e8a3d9644ce6c32dcddd9f1bac888a78aa"}, + {file = "torch-1.8.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:6ecdbd4494b4bf2d31a24ddfbdff32bd995389bc8662a454bd40d3e8ce202907"}, + {file = "torch-1.8.0-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:08aff0383e868f1e9882b732bbe6934defab690ad1745a03d5f1a150a4e1aeba"}, + {file = "torch-1.8.0-cp37-cp37m-win_amd64.whl", hash = "sha256:c87c7b0fd31c331968674cb73e82396a622b06a8e20425584922b767f2ffb259"}, + {file = "torch-1.8.0-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:b9d6c8c457b90b5167f3ab0bd1ff7193a06935533176bc6d41e1763d353e9740"}, + {file = "torch-1.8.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:fa1e391cca3937d5dea31f31a1a80a01bd4a8062c039448c254bbf5a58eb0787"}, + {file = "torch-1.8.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:affef9bce6eed232308dd89d55d3a37a105f35460f4705375980d27154c51e24"}, + {file = "torch-1.8.0-cp38-cp38-win_amd64.whl", hash = "sha256:287a7677df844bf2c4425698fd6d9434065211219cd7fd96000ed981c4d92288"}, + {file = "torch-1.8.0-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:1b58f70c150e066bcd7401a3bdfad661a04244817a5dac9990b5367523887d3f"}, + {file = "torch-1.8.0-cp38-none-macosx_11_1_arm64.whl", hash = "sha256:923856c2e6e53d5a747d83ff40faadd791d27cea2fd881b8d6990ea269f47572"}, + {file = "torch-1.8.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:2318fac860ae73dc6486c0de2223674d9ef6139fc75f157af2bf8dce4fca5524"}, + {file = "torch-1.8.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:05b114cb793816cd140794d5874f463972cb639f3b55d3a060f21fd066f5b629"}, + {file = "torch-1.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:7438431e03af793979cb1a9b5dd9f399b38461748e9f21f60e36149ee215d751"}, + {file = "torch-1.8.0-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:d98d167994d2e30df61a98eaca1684c50761f096d7f76c0c99789ac8cea50b55"}, ] torch-summary = [ {file = "torch-summary-1.4.3.tar.gz", hash = "sha256:2dcbc1dfd07dca9f4080bcacdaf90db3f2fc28efee348c8fba9033039b0e8c82"}, {file = "torch_summary-1.4.3-py3-none-any.whl", hash = "sha256:a0a76916bd11d054fd3863dc7c474971922badfbc13d6404f9eddd297041f094"}, ] torchvision = [ - {file = "torchvision-0.8.2-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:86fae370d222f76ad57c57c3bee03f78b8db727743bfb4c1559a3d395159cea8"}, - {file = "torchvision-0.8.2-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:951239b5fcb911dbf78c1385d677f5f48c7a1b12859e3d3ec287562821b17cf2"}, - {file = "torchvision-0.8.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:24db8f4c3d812a032273f68563ad5dbd724f5bfbed523d0c6dce8cede26bb153"}, - {file = "torchvision-0.8.2-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:b068f6bcbe91bdd34dda0a39e8a26392add45a3be82543f6dd523b76484fb56f"}, - {file = "torchvision-0.8.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:afb76a66b9b0693f758a881a2bf333ed97e3c0c3f15a413c4f49d8dd8bd21307"}, - {file = "torchvision-0.8.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:cd8817e9197fc60ebae37162a445db90bbf35591314a5767ad3d1490b5d65b0f"}, - {file = "torchvision-0.8.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1bd58acc3366ec02266aae56a7a752d43ef07de4a6ba420c4f907d0c9168bb8c"}, - {file = "torchvision-0.8.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:976750a49db2e23dc5a1ed0b5c31f7af51ed2702eee410ee09ef985c3a3e48cf"}, + {file = "torchvision-0.9.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:63052147c776d9f93385410c1d5a791386eb0cb5e1b93c7feac686f8dbe6eb06"}, + {file = "torchvision-0.9.0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:874714f30822d4c1160071dac004d48ae641bfdccccbb497098c86f6589ec0f1"}, + {file = "torchvision-0.9.0-cp36-cp36m-win_amd64.whl", hash = "sha256:b9f71f62725776495071b875494af86615f225b1a40902f5df452da5cfde0510"}, + {file = "torchvision-0.9.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:24b505dbcf3cb8da49d4b1447543c1021b699c84fc3701523101b62ee4adf097"}, + {file = "torchvision-0.9.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:d90750ae76a0cad8ffb6b509b30412dcd102d27d5f34f7184b289b6687de580e"}, + {file = "torchvision-0.9.0-cp37-cp37m-win_amd64.whl", hash = "sha256:8791da742c24344646a4ac36adee9327491f7fff7607dffe352402b5bf25ea21"}, + {file = "torchvision-0.9.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:fa302f6e8fe33a8d5c6649e659655c0427eee662fe22ce69eb56fa402b520c26"}, + {file = "torchvision-0.9.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:8fce78a59959f4bb4780a78c2277d617e44da7bc270bc449ff403187f6b587fc"}, + {file = "torchvision-0.9.0-cp38-cp38-win_amd64.whl", hash = "sha256:2252bc63fcccb27785726dd9d0d9a97432657a5d139390bf93cd6bdf227a4401"}, + {file = "torchvision-0.9.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:421bda7131f3c0eae2260f10174ac3c49e54183b33acb927b4b572f4cd90066d"}, + {file = "torchvision-0.9.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:d85d405e8cf694c1f85da7f0496ea69dd4f8d8dafbdad1e29bcdc4c621fc5cf0"}, + {file = "torchvision-0.9.0-cp39-cp39-win_amd64.whl", hash = "sha256:b03275f351feffaf7450d234ffb57cce26ff5e696d01ef5f543de205f18849a9"}, ] tornado = [ {file = "tornado-6.1-cp35-cp35m-macosx_10_9_x86_64.whl", hash = "sha256:d371e811d6b156d82aa5f9a4e08b58debf97c302a35714f6f45e35139c332e32"}, diff --git a/pyproject.toml b/pyproject.toml index ef75edf..33d539e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,8 +22,8 @@ sphinx_rtd_theme = "^0.4.3" boltons = "^20.1.0" h5py = "^3.2.1" toml = "^0.10.1" -torch = "^1.7.0" -torchvision = "^0.8.1" +torch = "^1.8.0" +torchvision = "^0.9.0" loguru = "^0.5.0" matplotlib = "^3.2.1" tqdm = "^4.46.1" @@ -64,12 +64,13 @@ gpustat = "^0.6.0" redlock-py = "^1.0.8" wandb = "^0.10.11" graphviz = "^0.16" +scipy = "^1.6.1" [tool.coverage.report] fail_under = 50 [tool.poetry.scripts] -download-emnist = "text_recognizer.datasets.util:download_emnist" +download-emnist = "text_recognizer.datasets.emnist:download_emnist" download-iam = "text_recognizer.datasets.iam_dataset:main" create-emnist-support-files = "text_recognizer.tests.support.create_emnist_support_files:create_emnist_support_files" create-emnist-lines-datasets = "text_recognizer.datasets.emnist_lines_dataset:create_datasets" @@ -78,7 +79,6 @@ prepare-experiments = "training.prepare_experiments:run_cli" run-experiment = "training.run_experiment:run_cli" - [build-system] requires = ["poetry>=0.12"] build-backend = "poetry.masonry.api" diff --git a/text_recognizer/tests/__init__.py b/tests/__init__.py index 18ff212..18ff212 100644 --- a/text_recognizer/tests/__init__.py +++ b/tests/__init__.py diff --git a/text_recognizer/tests/support/__init__.py b/tests/support/__init__.py index a265ede..a265ede 100644 --- a/text_recognizer/tests/support/__init__.py +++ b/tests/support/__init__.py diff --git a/text_recognizer/tests/support/create_emnist_lines_support_files.py b/tests/support/create_emnist_lines_support_files.py index 9abe143..9abe143 100644 --- a/text_recognizer/tests/support/create_emnist_lines_support_files.py +++ b/tests/support/create_emnist_lines_support_files.py diff --git a/text_recognizer/tests/support/create_emnist_support_files.py b/tests/support/create_emnist_support_files.py index f9ff030..f9ff030 100644 --- a/text_recognizer/tests/support/create_emnist_support_files.py +++ b/tests/support/create_emnist_support_files.py diff --git a/text_recognizer/tests/support/create_iam_lines_support_files.py b/tests/support/create_iam_lines_support_files.py index 50f9e3d..50f9e3d 100644 --- a/text_recognizer/tests/support/create_iam_lines_support_files.py +++ b/tests/support/create_iam_lines_support_files.py diff --git a/text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png b/tests/support/emnist_lines/Knox Ky<eos>.png Binary files differindex b7d0618..b7d0618 100644 --- a/text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png +++ b/tests/support/emnist_lines/Knox Ky<eos>.png diff --git a/text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png b/tests/support/emnist_lines/ancillary beliefs and<eos>.png Binary files differindex 14a8cf3..14a8cf3 100644 --- a/text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png +++ b/tests/support/emnist_lines/ancillary beliefs and<eos>.png diff --git a/text_recognizer/tests/support/emnist_lines/they<eos>.png b/tests/support/emnist_lines/they<eos>.png Binary files differindex 7f05951..7f05951 100644 --- a/text_recognizer/tests/support/emnist_lines/they<eos>.png +++ b/tests/support/emnist_lines/they<eos>.png diff --git a/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png b/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png Binary files differindex 6eeb642..6eeb642 100644 --- a/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png +++ b/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png diff --git a/text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png b/tests/support/iam_lines/and came into the livingroom, where<eos>.png Binary files differindex 4974cf8..4974cf8 100644 --- a/text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png +++ b/tests/support/iam_lines/and came into the livingroom, where<eos>.png diff --git a/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png b/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png Binary files differindex a731245..a731245 100644 --- a/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png +++ b/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png diff --git a/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg b/tests/support/iam_paragraphs/a01-000u.jpg Binary files differindex d9753b6..d9753b6 100644 --- a/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg +++ b/tests/support/iam_paragraphs/a01-000u.jpg diff --git a/text_recognizer/tests/test_character_predictor.py b/tests/test_character_predictor.py index 01bda78..01bda78 100644 --- a/text_recognizer/tests/test_character_predictor.py +++ b/tests/test_character_predictor.py diff --git a/text_recognizer/tests/test_line_predictor.py b/tests/test_line_predictor.py index eede4d4..eede4d4 100644 --- a/text_recognizer/tests/test_line_predictor.py +++ b/tests/test_line_predictor.py diff --git a/text_recognizer/tests/test_paragraph_text_recognizer.py b/tests/test_paragraph_text_recognizer.py index 3e280b9..3e280b9 100644 --- a/text_recognizer/tests/test_paragraph_text_recognizer.py +++ b/tests/test_paragraph_text_recognizer.py diff --git a/text_recognizer/datasets/__init__.py b/text_recognizer/datasets/__init__.py index a6c1c59..2727b20 100644 --- a/text_recognizer/datasets/__init__.py +++ b/text_recognizer/datasets/__init__.py @@ -1,39 +1 @@ """Dataset modules.""" -from .emnist_dataset import EmnistDataset -from .emnist_lines_dataset import ( - construct_image_from_string, - EmnistLinesDataset, - get_samples_by_character, -) -from .iam_dataset import IamDataset -from .iam_lines_dataset import IamLinesDataset -from .iam_paragraphs_dataset import IamParagraphsDataset -from .iam_preprocessor import load_metadata, Preprocessor -from .transforms import AddTokens, Transpose -from .util import ( - _download_raw_dataset, - compute_sha256, - DATA_DIRNAME, - download_url, - EmnistMapper, - ESSENTIALS_FILENAME, -) - -__all__ = [ - "_download_raw_dataset", - "AddTokens", - "compute_sha256", - "construct_image_from_string", - "DATA_DIRNAME", - "download_url", - "EmnistDataset", - "EmnistMapper", - "EmnistLinesDataset", - "get_samples_by_character", - "load_metadata", - "IamDataset", - "IamLinesDataset", - "IamParagraphsDataset", - "Preprocessor", - "Transpose", -] diff --git a/text_recognizer/datasets/base_data_module.py b/text_recognizer/datasets/base_data_module.py index 09a0a43..830b39b 100644 --- a/text_recognizer/datasets/base_data_module.py +++ b/text_recognizer/datasets/base_data_module.py @@ -16,7 +16,7 @@ def load_and_print_info(data_module_class: type) -> None: class BaseDataModule(pl.LightningDataModule): """Base PyTorch Lightning DataModule.""" - + def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None: super().__init__() self.batch_size = batch_size @@ -34,13 +34,17 @@ class BaseDataModule(pl.LightningDataModule): def config(self) -> Dict: """Return important settings of the dataset.""" - return {"input_dim": self.dims, "output_dims": self.output_dims, "mapping": self.mapping} + return { + "input_dim": self.dims, + "output_dims": self.output_dims, + "mapping": self.mapping, + } def prepare_data(self) -> None: """Prepare data for training.""" pass - def setup(self, stage: Any = None) -> None: + def setup(self, stage: str = None) -> None: """Split into train, val, test, and set dims. Should assign `torch Dataset` objects to self.data_train, self.data_val, and @@ -54,16 +58,32 @@ class BaseDataModule(pl.LightningDataModule): self.data_val = None self.data_test = None - def train_dataloader(self) -> DataLoader: """Retun DataLoader for train data.""" - return DataLoader(self.data_train, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True) + return DataLoader( + self.data_train, + shuffle=True, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=True, + ) def val_dataloader(self) -> DataLoader: """Return DataLoader for val data.""" - return DataLoader(self.data_val, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True) + return DataLoader( + self.data_val, + shuffle=False, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=True, + ) def test_dataloader(self) -> DataLoader: """Return DataLoader for val data.""" - return DataLoader(self.data_test, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True) - + return DataLoader( + self.data_test, + shuffle=False, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=True, + ) diff --git a/text_recognizer/datasets/base_dataset.py b/text_recognizer/datasets/base_dataset.py index 7322d7f..a004b8d 100644 --- a/text_recognizer/datasets/base_dataset.py +++ b/text_recognizer/datasets/base_dataset.py @@ -17,12 +17,14 @@ class BaseDataset(Dataset): target_transform (Callable): Fucntion that takes a target and applies target transforms. """ - def __init__(self, - data: Union[Sequence, Tensor], - targets: Union[Sequence, Tensor], - transform: Callable = None, - target_transform: Callable = None, - ) -> None: + + def __init__( + self, + data: Union[Sequence, Tensor], + targets: Union[Sequence, Tensor], + transform: Callable = None, + target_transform: Callable = None, + ) -> None: if len(data) != len(targets): raise ValueError("Data and targets must be of equal length.") self.data = data @@ -30,11 +32,10 @@ class BaseDataset(Dataset): self.transform = transform self.target_transform = target_transform - def __len__(self) -> int: """Return the length of the dataset.""" return len(self.data) - + def __getitem__(self, index: int) -> Tuple[Any, Any]: """Return a datum and its target, after processing by transforms. @@ -56,7 +57,9 @@ class BaseDataset(Dataset): return datum, target -def convert_strings_to_labels(strings: Sequence[str], mapping: Dict[str, int], length: int) -> Tensor: +def convert_strings_to_labels( + strings: Sequence[str], mapping: Dict[str, int], length: int +) -> Tensor: """ Convert a sequence of N strings to (N, length) ndarray, with each string wrapped with <S> and </S> tokens, and padded wiht the <P> token. diff --git a/text_recognizer/datasets/dataset.py b/text_recognizer/datasets/dataset.py deleted file mode 100644 index e794605..0000000 --- a/text_recognizer/datasets/dataset.py +++ /dev/null @@ -1,152 +0,0 @@ -"""Abstract dataset class.""" -from typing import Callable, Dict, List, Optional, Tuple, Union - -import torch -from torch import Tensor -from torch.utils import data -from torchvision.transforms import ToTensor - -import text_recognizer.datasets.transforms as transforms -from text_recognizer.datasets.util import EmnistMapper - - -class Dataset(data.Dataset): - """Abstract class for with common methods for all datasets.""" - - def __init__( - self, - train: bool, - subsample_fraction: float = None, - transform: Optional[List[Dict]] = None, - target_transform: Optional[List[Dict]] = None, - init_token: Optional[str] = None, - pad_token: Optional[str] = None, - eos_token: Optional[str] = None, - lower: bool = False, - ) -> None: - """Initialization of Dataset class. - - Args: - train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False. - subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None. - transform (Optional[List[Dict]]): List of Transform types and args for input data. Defaults to None. - target_transform (Optional[List[Dict]]): List of Transform types and args for output data. Defaults to None. - init_token (Optional[str]): String representing the start of sequence token. Defaults to None. - pad_token (Optional[str]): String representing the pad token. Defaults to None. - eos_token (Optional[str]): String representing the end of sequence token. Defaults to None. - lower (bool): Only use lower case letters. Defaults to False. - - Raises: - ValueError: If subsample_fraction is not None and outside the range (0, 1). - - """ - self.train = train - self.split = "train" if self.train else "test" - - if subsample_fraction is not None: - if not 0.0 < subsample_fraction < 1.0: - raise ValueError("The subsample fraction must be in (0, 1).") - self.subsample_fraction = subsample_fraction - - self._mapper = EmnistMapper( - init_token=init_token, eos_token=eos_token, pad_token=pad_token, lower=lower - ) - self._input_shape = self._mapper.input_shape - self._output_shape = self._mapper._num_classes - self.num_classes = self.mapper.num_classes - - # Set transforms. - self.transform = self._configure_transform(transform) - self.target_transform = self._configure_target_transform(target_transform) - - self._data = None - self._targets = None - - def _configure_transform(self, transform: List[Dict]) -> transforms.Compose: - transform_list = [] - if transform is not None: - for t in transform: - t_type = t["type"] - t_args = t["args"] or {} - transform_list.append(getattr(transforms, t_type)(**t_args)) - else: - transform_list.append(ToTensor()) - return transforms.Compose(transform_list) - - def _configure_target_transform( - self, target_transform: List[Dict] - ) -> transforms.Compose: - target_transform_list = [torch.tensor] - if target_transform is not None: - for t in target_transform: - t_type = t["type"] - t_args = t["args"] or {} - target_transform_list.append(getattr(transforms, t_type)(**t_args)) - return transforms.Compose(target_transform_list) - - @property - def data(self) -> Tensor: - """The input data.""" - return self._data - - @property - def targets(self) -> Tensor: - """The target data.""" - return self._targets - - @property - def input_shape(self) -> Tuple: - """Input shape of the data.""" - return self._input_shape - - @property - def output_shape(self) -> Tuple: - """Output shape of the data.""" - return self._output_shape - - @property - def mapper(self) -> EmnistMapper: - """Returns the EmnistMapper.""" - return self._mapper - - @property - def mapping(self) -> Dict: - """Return EMNIST mapping from index to character.""" - return self._mapper.mapping - - @property - def inverse_mapping(self) -> Dict: - """Returns the inverse mapping from character to index.""" - return self.mapper.inverse_mapping - - def _subsample(self) -> None: - """Only this fraction of the data will be loaded.""" - if self.subsample_fraction is None: - return - num_subsample = int(self.data.shape[0] * self.subsample_fraction) - self._data = self.data[:num_subsample] - self._targets = self.targets[:num_subsample] - - def __len__(self) -> int: - """Returns the length of the dataset.""" - return len(self.data) - - def load_or_generate_data(self) -> None: - """Load or generate dataset data.""" - raise NotImplementedError - - def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]: - """Fetches samples from the dataset. - - Args: - index (Union[int, torch.Tensor]): The indices of the samples to fetch. - - Raises: - NotImplementedError: If the method is not implemented in child class. - - """ - raise NotImplementedError - - def __repr__(self) -> str: - """Returns information about the dataset.""" - raise NotImplementedError diff --git a/text_recognizer/datasets/download_utils.py b/text_recognizer/datasets/download_utils.py index 7a2cab8..e3dc68c 100644 --- a/text_recognizer/datasets/download_utils.py +++ b/text_recognizer/datasets/download_utils.py @@ -63,11 +63,11 @@ def download_dataset(metadata: Dict, dl_dir: Path) -> Optional[Path]: if filename.exists(): return logger.info(f"Downloading raw dataset from {metadata['url']} to {filename}...") - _download_url(metadata["url"], filename) + _download_url(metadata["url"], filename) logger.info("Computing the SHA-256...") sha256 = _compute_sha256(filename) if sha256 != metadata["sha256"]: raise ValueError( - "Downloaded data file SHA-256 does not match that listed in metadata document." - ) + "Downloaded data file SHA-256 does not match that listed in metadata document." + ) return filename diff --git a/text_recognizer/datasets/emnist.py b/text_recognizer/datasets/emnist.py index e99dbfd..7c208c4 100644 --- a/text_recognizer/datasets/emnist.py +++ b/text_recognizer/datasets/emnist.py @@ -15,20 +15,23 @@ from torch.utils.data import random_split from torchvision import transforms from text_recognizer.datasets.base_dataset import BaseDataset -from text_recognizer.datasets.base_data_module import BaseDataModule, load_print_info +from text_recognizer.datasets.base_data_module import ( + BaseDataModule, + load_and_print_info, +) from text_recognizer.datasets.download_utils import download_dataset SEED = 4711 NUM_SPECIAL_TOKENS = 4 -SAMPLE_TO_BALANCE = True +SAMPLE_TO_BALANCE = True RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "emnist" METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "emnist" -PROCESSED_DATA_DIRNAME = BaseDataset.data_dirname() / "processed" / "emnist" +PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist" PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5" -ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnsit_essentials.json" +ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json" class EMNIST(BaseDataModule): @@ -41,7 +44,9 @@ class EMNIST(BaseDataModule): EMNIST ByClass: 814,255 characters. 62 unbalanced classes. """ - def __init__(self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8) -> None: + def __init__( + self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8 + ) -> None: super().__init__(batch_size, num_workers) if not ESSENTIALS_FILENAME.exists(): _download_and_process_emnist() @@ -64,20 +69,21 @@ class EMNIST(BaseDataModule): def setup(self, stage: str = None) -> None: if stage == "fit" or stage is None: with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: - data = f["x_train"][:] - targets = f["y_train"][:] - - dataset_train = BaseDataset(data, targets, transform=self.transform) + self.x_train = f["x_train"][:] + self.y_train = f["y_train"][:] + + dataset_train = BaseDataset(self.x_train, self.y_train, transform=self.transform) train_size = int(self.train_fraction * len(dataset_train)) val_size = len(dataset_train) - train_size - self.data_train, self.data_val = random_split(dataset_train, [train_size, val_size], generator=torch.Generator()) + self.data_train, self.data_val = random_split( + dataset_train, [train_size, val_size], generator=torch.Generator() + ) if stage == "test" or stage is None: with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: - data = f["x_test"][:] - targets = f["y_test"][:] - self.data_test = BaseDataset(data, targets, transform=self.transform) - + self.x_test = f["x_test"][:] + self.y_test = f["y_test"][:] + self.data_test = BaseDataset(self.x_test, self.y_test, transform=self.transform) def __repr__(self) -> str: basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.dims}\n" @@ -111,9 +117,15 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None: logger.info("Loading training data from .mat file") data = loadmat("matlab/emnist-byclass.mat") - x_train = data["dataset"]["train"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) + x_train = ( + data["dataset"]["train"][0, 0]["images"][0, 0] + .reshape(-1, 28, 28) + .swapaxes(1, 2) + ) y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS - x_test = data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) + x_test = ( + data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) + ) y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS if SAMPLE_TO_BALANCE: @@ -121,7 +133,6 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None: x_train, y_train = _sample_to_balance(x_train, y_train) x_test, y_test = _sample_to_balance(x_test, y_test) - logger.info("Saving to HDF5 in a compressed format...") PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) with h5py.File(PROCESSED_DATA_FILENAME, "w") as f: @@ -154,7 +165,7 @@ def _sample_to_balance(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.nda all_sampled_indices.append(sampled_indices) indices = np.concatenate(all_sampled_indices) x_sampled = x[indices] - y_sampled= y[indices] + y_sampled = y[indices] return x_sampled, y_sampled @@ -162,24 +173,24 @@ def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]: """Augment the mapping with extra symbols.""" # Extra characters from the IAM dataset. iam_characters = [ - " ", - "!", - '"', - "#", - "&", - "'", - "(", - ")", - "*", - "+", - ",", - "-", - ".", - "/", - ":", - ";", - "?", - ] + " ", + "!", + '"', + "#", + "&", + "'", + "(", + ")", + "*", + "+", + ",", + "-", + ".", + "/", + ":", + ";", + "?", + ] # Also add special tokens for: # - CTC blank token at index 0 @@ -190,5 +201,6 @@ def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]: return ["<b>", "<s>", "</s>", "<p>", *characters, *iam_characters] -if __name__ == "__main__": - load_print_info(EMNIST) +def download_emnist() -> None: + """Download dataset from internet, if it does not exists, and displays info.""" + load_and_print_info(EMNIST) diff --git a/text_recognizer/datasets/emnist_essentials.json b/text_recognizer/datasets/emnist_essentials.json new file mode 100644 index 0000000..100b36a --- /dev/null +++ b/text_recognizer/datasets/emnist_essentials.json @@ -0,0 +1 @@ +{"characters": ["<b>", "<s>", "</s>", "<p>", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]}
\ No newline at end of file diff --git a/text_recognizer/datasets/emnist_lines.py b/text_recognizer/datasets/emnist_lines.py new file mode 100644 index 0000000..ae23feb --- /dev/null +++ b/text_recognizer/datasets/emnist_lines.py @@ -0,0 +1,184 @@ +"""Dataset of generated text from EMNIST characters.""" +from collections import defaultdict +from pathlib import Path +from typing import Dict, Sequence + +import h5py +from loguru import logger +import numpy as np +import torch +from torchvision import transforms + +from text_recognizer.datasets.base_dataset import BaseDataset +from text_recognizer.datasets.base_data_module import BaseDataModule +from text_recognizer.datasets.emnist import EMNIST +from text_recognizer.datasets.sentence_generator import SentenceGenerator + + +DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist_lines" +ESSENTIALS_FILENAME = ( + Path(__file__).parents[0].resolve() / "emnist_lines_essentials.json" +) + +SEED = 4711 +IMAGE_HEIGHT = 56 +IMAGE_WIDTH = 1024 +IMAGE_X_PADDING = 28 +MAX_OUTPUT_LENGTH = 89 # Same as IAMLines + + +class EMNISTLines(BaseDataModule): + """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST,""" + + def __init__( + self, + augment: bool = True, + batch_size: int = 128, + num_workers: int = 0, + max_length: int = 32, + min_overlap: float = 0.0, + max_overlap: float = 0.33, + num_train: int = 10_000, + num_val: int = 2_000, + num_test: int = 2_000, + ) -> None: + super().__init__(batch_size, num_workers) + + self.augment = augment + self.max_length = max_length + self.min_overlap = min_overlap + self.max_overlap = max_overlap + self.num_train = num_train + self.num_val = num_val + self.num_test = num_test + + self.emnist = EMNIST() + self.mapping = self.emnist.mapping + max_width = int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap)) + IMAGE_X_PADDING + + if max_width <= IMAGE_WIDTH: + raise ValueError("max_width greater than IMAGE_WIDTH") + + self.dims = ( + self.emnist.dims[0], + self.emnist.dims[1], + self.emnist.dims[2] * self.max_length, + ) + + if self.max_length <= MAX_OUTPUT_LENGTH: + raise ValueError("max_length greater than MAX_OUTPUT_LENGTH") + + self.output_dims = (MAX_OUTPUT_LENGTH, 1) + self.data_train = None + self.data_val = None + self.data_test = None + + @property + def data_filename(self) -> Path: + """Return name of dataset.""" + return ( + DATA_DIRNAME + / f"ml_{self.max_length}_o{self.min_overlap:f}_{self.max_overlap:f}_ntr{self.num_train}_ntv{self.num_val}_nte{self.num_test}_{self.with_start_end_tokens}.h5" + ) + + def prepare_data(self) -> None: + if self.data_filename.exists(): + return + np.random.seed(SEED) + self._generate_data("train") + self._generate_data("val") + self._generate_data("test") + + def setup(self, stage: str = None) -> None: + logger.info("EMNISTLinesDataset loading data from HDF5...") + if stage == "fit" or stage is None: + with h5py.File(self.data_filename, "r") as f: + x_train = f["x_train"][:] + y_train = torch.LongTensor(f["y_train"][:]) + x_val = f["x_val"][:] + y_val = torch.LongTensor(f["y_val"][:]) + + self.data_train = BaseDataset(x_train, y_train, transform=_get_transform(augment=self.augment)) + self.data_val = BaseDataset(x_val, y_val, transform=_get_transform(augment=self.augment)) + + if stage == "test" or stage is None: + with h5py.File(self.data_filename, "r") as f: + x_test = f["x_test"][:] + y_test = torch.LongTensor(f["y_test"][:]) + + self.data_train = BaseDataset(x_test, y_test, transform=_get_transform(augment=False)) + + def __repr__(self) -> str: + """Return str about dataset.""" + basic = ( + "EMNISTLines2 Dataset\n" # pylint: disable=no-member + f"Min overlap: {self.min_overlap}\n" + f"Max overlap: {self.max_overlap}\n" + f"Num classes: {len(self.mapping)}\n" + f"Dims: {self.dims}\n" + f"Output dims: {self.output_dims}\n" + ) + + if not any([self.data_train, self.data_val, self.data_test]): + return basic + + x, y = next(iter(self.train_dataloader())) + data = ( + f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" + f"Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" + f"Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" + ) + return basic + data + + def _generate_data(self, split: str) -> None: + logger.info(f"EMNISTLines generating data for {split}...") + sentence_generator = SentenceGenerator(self.max_length - 2) # Subtract by 2 because start/end token + + emnist = self.emnist + emnist.prepare_data() + emnist.setup() + + if split == "train": + samples_by_char = _get_samples_by_char(emnist.x_train, emnist.y_train, emnist.mapping) + num = self.num_train + elif split == "val": + samples_by_char = _get_samples_by_char(emnist.x_train, emnist.y_train, emnist.mapping) + num = self.num_val + elif split == "test": + samples_by_char = _get_samples_by_char(emnist.x_test, emnist.y_test, emnist.mapping) + num = self.num_test + + DATA_DIRNAME.mkdir(parents=True, exist_ok=True) + with h5py.File(self.data_filename, "w") as f: + x, y = _create_dataset_of_images( + num, samples_by_char, sentence_generator, self.min_overlap, self.max_overlap, self.dims + ) + y = _convert_strings_to_labels( + y, + emnist.inverse_mapping, + length=MAX_OUTPUT_LENGTH + ) + f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf") + f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf") + +def _get_samples_by_char(samples: np.ndarray, labels: np.ndarray, mapping: Dict) -> defaultdict: + samples_by_char = defaultdict(list) + for sample, label in zip(samples, labels): + samples_by_char[mapping[label]].append(sample) + return samples_by_char + + +def _construct_image_from_string(): + pass + + +def _select_letter_samples_for_string(string: str, samples_by_char: defaultdict): + pass + + +def _create_dataset_of_images(num_samples: int, samples_by_char: defaultdict, sentence_generator: SentenceGenerator, min_overlap: float, max_overlap: float, dims: Tuple) -> Tuple[torch.Tensor, torch.Tensor]: + images = torch.zeros((num_samples, IMAGE_HEIGHT, dims[2])) + labels = [] + for n in range(num_samples): + label = sentence_generator.generate() + crop = _construct_image_from_string() |