From e3741de333a3a43a7968241b6eccaaac66dd7b20 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 21 Mar 2021 22:33:58 +0100 Subject: Working on EMNIST Lines dataset --- data/EMNIST/raw/metadata.toml | 3 - noxfile.py | 9 +- poetry.lock | 107 ++++++------ pyproject.toml | 8 +- tests/__init__.py | 1 + tests/support/__init__.py | 2 + tests/support/create_emnist_lines_support_files.py | 51 ++++++ tests/support/create_emnist_support_files.py | 30 ++++ tests/support/create_iam_lines_support_files.py | 50 ++++++ tests/support/emnist_lines/Knox Ky.png | Bin 0 -> 2301 bytes .../emnist_lines/ancillary beliefs and.png | Bin 0 -> 5424 bytes tests/support/emnist_lines/they.png | Bin 0 -> 1391 bytes .../He rose from his breakfast-nook bench.png | Bin 0 -> 5170 bytes .../and came into the livingroom, where.png | Bin 0 -> 3617 bytes .../his entrance. He came, almost falling.png | Bin 0 -> 3923 bytes tests/support/iam_paragraphs/a01-000u.jpg | Bin 0 -> 14890 bytes tests/test_character_predictor.py | 31 ++++ tests/test_line_predictor.py | 35 ++++ tests/test_paragraph_text_recognizer.py | 37 +++++ text_recognizer/datasets/__init__.py | 38 ----- text_recognizer/datasets/base_data_module.py | 36 +++- text_recognizer/datasets/base_dataset.py | 21 ++- text_recognizer/datasets/dataset.py | 152 ----------------- text_recognizer/datasets/download_utils.py | 6 +- text_recognizer/datasets/emnist.py | 88 +++++----- text_recognizer/datasets/emnist_essentials.json | 1 + text_recognizer/datasets/emnist_lines.py | 184 +++++++++++++++++++++ text_recognizer/tests/__init__.py | 1 - text_recognizer/tests/support/__init__.py | 2 - .../support/create_emnist_lines_support_files.py | 51 ------ .../tests/support/create_emnist_support_files.py | 30 ---- .../support/create_iam_lines_support_files.py | 50 ------ .../tests/support/emnist_lines/Knox Ky.png | Bin 2301 -> 0 bytes .../emnist_lines/ancillary beliefs and.png | Bin 5424 -> 0 bytes .../tests/support/emnist_lines/they.png | Bin 1391 -> 0 bytes .../He rose from his breakfast-nook bench.png | Bin 5170 -> 0 bytes .../and came into the livingroom, where.png | Bin 3617 -> 0 bytes .../his entrance. He came, almost falling.png | Bin 3923 -> 0 bytes .../tests/support/iam_paragraphs/a01-000u.jpg | Bin 14890 -> 0 bytes text_recognizer/tests/test_character_predictor.py | 31 ---- text_recognizer/tests/test_line_predictor.py | 35 ---- .../tests/test_paragraph_text_recognizer.py | 37 ----- 42 files changed, 582 insertions(+), 545 deletions(-) delete mode 100644 data/EMNIST/raw/metadata.toml create mode 100644 tests/__init__.py create mode 100644 tests/support/__init__.py create mode 100644 tests/support/create_emnist_lines_support_files.py create mode 100644 tests/support/create_emnist_support_files.py create mode 100644 tests/support/create_iam_lines_support_files.py create mode 100644 tests/support/emnist_lines/Knox Ky.png create mode 100644 tests/support/emnist_lines/ancillary beliefs and.png create mode 100644 tests/support/emnist_lines/they.png create mode 100644 tests/support/iam_lines/He rose from his breakfast-nook bench.png create mode 100644 tests/support/iam_lines/and came into the livingroom, where.png create mode 100644 tests/support/iam_lines/his entrance. He came, almost falling.png create mode 100644 tests/support/iam_paragraphs/a01-000u.jpg create mode 100644 tests/test_character_predictor.py create mode 100644 tests/test_line_predictor.py create mode 100644 tests/test_paragraph_text_recognizer.py delete mode 100644 text_recognizer/datasets/dataset.py create mode 100644 text_recognizer/datasets/emnist_essentials.json create mode 100644 text_recognizer/datasets/emnist_lines.py delete mode 100644 text_recognizer/tests/__init__.py delete mode 100644 text_recognizer/tests/support/__init__.py delete mode 100644 text_recognizer/tests/support/create_emnist_lines_support_files.py delete mode 100644 text_recognizer/tests/support/create_emnist_support_files.py delete mode 100644 text_recognizer/tests/support/create_iam_lines_support_files.py delete mode 100644 text_recognizer/tests/support/emnist_lines/Knox Ky.png delete mode 100644 text_recognizer/tests/support/emnist_lines/ancillary beliefs and.png delete mode 100644 text_recognizer/tests/support/emnist_lines/they.png delete mode 100644 text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench.png delete mode 100644 text_recognizer/tests/support/iam_lines/and came into the livingroom, where.png delete mode 100644 text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling.png delete mode 100644 text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg delete mode 100644 text_recognizer/tests/test_character_predictor.py delete mode 100644 text_recognizer/tests/test_line_predictor.py delete mode 100644 text_recognizer/tests/test_paragraph_text_recognizer.py diff --git a/data/EMNIST/raw/metadata.toml b/data/EMNIST/raw/metadata.toml deleted file mode 100644 index 10304ce..0000000 --- a/data/EMNIST/raw/metadata.toml +++ /dev/null @@ -1,3 +0,0 @@ -filename = 'gzip.zip' -md5 = '58c8d27c78d21e728a6bc7b3cc06412e' -url = 'http://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip' diff --git a/noxfile.py b/noxfile.py index 098a551..0e2ac7b 100644 --- a/noxfile.py +++ b/noxfile.py @@ -8,7 +8,14 @@ from nox.sessions import Session package = "text-recognizer" nox.options.sessions = "lint", "mypy", "pytype", "safety", "tests" -locations = "src", "tests", "noxfile.py", "docs/conf.py", "src/text_recognizer/tests" +locations = ( + "text_recognizer", + "training", + "tasks", + "tests", + "noxfile.py", + "docs/conf.py", +) def install_with_constraints(session: Session, *args: str, **kwargs: Any) -> None: diff --git a/poetry.lock b/poetry.lock index a389e98..094043d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1730,14 +1730,14 @@ alldeps = ["numpy (>=1.13.3)", "scipy (>=0.19.1)"] [[package]] name = "scipy" -version = "1.5.4" +version = "1.6.1" description = "SciPy: Scientific Library for Python" category = "main" optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" [package.dependencies] -numpy = ">=1.14.5" +numpy = ">=1.16.5" [[package]] name = "send2trash" @@ -2033,7 +2033,7 @@ python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" [[package]] name = "torch" -version = "1.7.1" +version = "1.8.0" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" category = "main" optional = false @@ -2053,7 +2053,7 @@ python-versions = ">=3.5" [[package]] name = "torchvision" -version = "0.8.2" +version = "0.9.0" description = "image and video datasets and models for torch deep learning" category = "main" optional = false @@ -2062,7 +2062,7 @@ python-versions = "*" [package.dependencies] numpy = "*" pillow = ">=4.1.1" -torch = "1.7.1" +torch = "1.8.0" [package.extras] scipy = ["scipy"] @@ -2279,7 +2279,7 @@ multidict = ">=4.0" [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "45c6282e27a0231e3dc3c951a9540f3cd2f6d6eb3d3eda1eab84168014d8062a" +content-hash = "82834b5e12f38b06a306f1b8c05ed6ee33d671c747de40615d2bd36f3f138279" [metadata.files] absl-py = [ @@ -3385,31 +3385,25 @@ scikit-learn = [ {file = "scikit_learn-0.23.2-cp38-cp38-win_amd64.whl", hash = "sha256:1b8a391de95f6285a2f9adffb7db0892718950954b7149a70c783dc848f104ea"}, ] scipy = [ - {file = "scipy-1.5.4-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:4f12d13ffbc16e988fa40809cbbd7a8b45bc05ff6ea0ba8e3e41f6f4db3a9e47"}, - {file = "scipy-1.5.4-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:a254b98dbcc744c723a838c03b74a8a34c0558c9ac5c86d5561703362231107d"}, - {file = "scipy-1.5.4-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:368c0f69f93186309e1b4beb8e26d51dd6f5010b79264c0f1e9ca00cd92ea8c9"}, - {file = "scipy-1.5.4-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:4598cf03136067000855d6b44d7a1f4f46994164bcd450fb2c3d481afc25dd06"}, - {file = "scipy-1.5.4-cp36-cp36m-win32.whl", hash = "sha256:e98d49a5717369d8241d6cf33ecb0ca72deee392414118198a8e5b4c35c56340"}, - {file = "scipy-1.5.4-cp36-cp36m-win_amd64.whl", hash = "sha256:65923bc3809524e46fb7eb4d6346552cbb6a1ffc41be748535aa502a2e3d3389"}, - {file = "scipy-1.5.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:9ad4fcddcbf5dc67619379782e6aeef41218a79e17979aaed01ed099876c0e62"}, - {file = "scipy-1.5.4-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:f87b39f4d69cf7d7529d7b1098cb712033b17ea7714aed831b95628f483fd012"}, - {file = "scipy-1.5.4-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:25b241034215247481f53355e05f9e25462682b13bd9191359075682adcd9554"}, - {file = "scipy-1.5.4-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:fa789583fc94a7689b45834453fec095245c7e69c58561dc159b5d5277057e4c"}, - {file = "scipy-1.5.4-cp37-cp37m-win32.whl", hash = "sha256:d6d25c41a009e3c6b7e757338948d0076ee1dd1770d1c09ec131f11946883c54"}, - {file = "scipy-1.5.4-cp37-cp37m-win_amd64.whl", hash = "sha256:2c872de0c69ed20fb1a9b9cf6f77298b04a26f0b8720a5457be08be254366c6e"}, - {file = "scipy-1.5.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e360cb2299028d0b0d0f65a5c5e51fc16a335f1603aa2357c25766c8dab56938"}, - {file = "scipy-1.5.4-cp38-cp38-manylinux1_i686.whl", hash = "sha256:3397c129b479846d7eaa18f999369a24322d008fac0782e7828fa567358c36ce"}, - {file = "scipy-1.5.4-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:168c45c0c32e23f613db7c9e4e780bc61982d71dcd406ead746c7c7c2f2004ce"}, - {file = "scipy-1.5.4-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:213bc59191da2f479984ad4ec39406bf949a99aba70e9237b916ce7547b6ef42"}, - {file = "scipy-1.5.4-cp38-cp38-win32.whl", hash = "sha256:634568a3018bc16a83cda28d4f7aed0d803dd5618facb36e977e53b2df868443"}, - {file = "scipy-1.5.4-cp38-cp38-win_amd64.whl", hash = "sha256:b03c4338d6d3d299e8ca494194c0ae4f611548da59e3c038813f1a43976cb437"}, - {file = "scipy-1.5.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3d5db5d815370c28d938cf9b0809dade4acf7aba57eaf7ef733bfedc9b2474c4"}, - {file = "scipy-1.5.4-cp39-cp39-manylinux1_i686.whl", hash = "sha256:6b0ceb23560f46dd236a8ad4378fc40bad1783e997604ba845e131d6c680963e"}, - {file = "scipy-1.5.4-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:ed572470af2438b526ea574ff8f05e7f39b44ac37f712105e57fc4d53a6fb660"}, - {file = "scipy-1.5.4-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:8c8d6ca19c8497344b810b0b0344f8375af5f6bb9c98bd42e33f747417ab3f57"}, - {file = "scipy-1.5.4-cp39-cp39-win32.whl", hash = "sha256:d84cadd7d7998433334c99fa55bcba0d8b4aeff0edb123b2a1dfcface538e474"}, - {file = "scipy-1.5.4-cp39-cp39-win_amd64.whl", hash = "sha256:cc1f78ebc982cd0602c9a7615d878396bec94908db67d4ecddca864d049112f2"}, - {file = "scipy-1.5.4.tar.gz", hash = "sha256:4a453d5e5689de62e5d38edf40af3f17560bfd63c9c5bd228c18c1f99afa155b"}, + {file = "scipy-1.6.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:a15a1f3fc0abff33e792d6049161b7795909b40b97c6cc2934ed54384017ab76"}, + {file = "scipy-1.6.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:e79570979ccdc3d165456dd62041d9556fb9733b86b4b6d818af7a0afc15f092"}, + {file = "scipy-1.6.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:a423533c55fec61456dedee7b6ee7dce0bb6bfa395424ea374d25afa262be261"}, + {file = "scipy-1.6.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:33d6b7df40d197bdd3049d64e8e680227151673465e5d85723b3b8f6b15a6ced"}, + {file = "scipy-1.6.1-cp37-cp37m-win32.whl", hash = "sha256:6725e3fbb47da428794f243864f2297462e9ee448297c93ed1dcbc44335feb78"}, + {file = "scipy-1.6.1-cp37-cp37m-win_amd64.whl", hash = "sha256:5fa9c6530b1661f1370bcd332a1e62ca7881785cc0f80c0d559b636567fab63c"}, + {file = "scipy-1.6.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bd50daf727f7c195e26f27467c85ce653d41df4358a25b32434a50d8870fc519"}, + {file = "scipy-1.6.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:f46dd15335e8a320b0fb4685f58b7471702234cba8bb3442b69a3e1dc329c345"}, + {file = "scipy-1.6.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:0e5b0ccf63155d90da576edd2768b66fb276446c371b73841e3503be1d63fb5d"}, + {file = "scipy-1.6.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:2481efbb3740977e3c831edfd0bd9867be26387cacf24eb5e366a6a374d3d00d"}, + {file = "scipy-1.6.1-cp38-cp38-win32.whl", hash = "sha256:68cb4c424112cd4be886b4d979c5497fba190714085f46b8ae67a5e4416c32b4"}, + {file = "scipy-1.6.1-cp38-cp38-win_amd64.whl", hash = "sha256:5f331eeed0297232d2e6eea51b54e8278ed8bb10b099f69c44e2558c090d06bf"}, + {file = "scipy-1.6.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0c8a51d33556bf70367452d4d601d1742c0e806cd0194785914daf19775f0e67"}, + {file = "scipy-1.6.1-cp39-cp39-manylinux1_i686.whl", hash = "sha256:83bf7c16245c15bc58ee76c5418e46ea1811edcc2e2b03041b804e46084ab627"}, + {file = "scipy-1.6.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:794e768cc5f779736593046c9714e0f3a5940bc6dcc1dba885ad64cbfb28e9f0"}, + {file = "scipy-1.6.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:5da5471aed911fe7e52b86bf9ea32fb55ae93e2f0fac66c32e58897cfb02fa07"}, + {file = "scipy-1.6.1-cp39-cp39-win32.whl", hash = "sha256:8e403a337749ed40af60e537cc4d4c03febddcc56cd26e774c9b1b600a70d3e4"}, + {file = "scipy-1.6.1-cp39-cp39-win_amd64.whl", hash = "sha256:a5193a098ae9f29af283dcf0041f762601faf2e595c0db1da929875b7570353f"}, + {file = "scipy-1.6.1.tar.gz", hash = "sha256:c4fceb864890b6168e79b0e714c585dbe2fd4222768ee90bc1aa0f8218691b11"}, ] send2trash = [ {file = "Send2Trash-1.5.0-py3-none-any.whl", hash = "sha256:f1691922577b6fa12821234aeb57599d887c4900b9ca537948d2dac34aea888b"}, @@ -3544,32 +3538,41 @@ toml = [ {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, ] torch = [ - {file = "torch-1.7.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:422e64e98d0e100c360993819d0307e5d56e9517b26135808ad68984d577d75a"}, - {file = "torch-1.7.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f0aaf657145533824b15f2fd8fde8f8c67fe6c6281088ef588091f03fad90243"}, - {file = "torch-1.7.1-cp36-none-macosx_10_9_x86_64.whl", hash = "sha256:af464a6f4314a875035e0c4c2b07517599704b214634f4ed3ad2e748c5ef291f"}, - {file = "torch-1.7.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:5d76c255a41484c1d41a9ff570b9c9f36cb85df9428aa15a58ae16ac7cfc2ea6"}, - {file = "torch-1.7.1-cp37-cp37m-win_amd64.whl", hash = "sha256:d241c3f1c4d563e4ba86f84769c23e12606db167ee6f674eedff6d02901462e3"}, - {file = "torch-1.7.1-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:de84b4166e3f7335eb868b51d3bbd909ec33828af27290b4171bce832a55be3c"}, - {file = "torch-1.7.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:dd2fc6880c95e836960d86efbbc7f63d3287f2e1893c51d31f96dbfe02f0d73e"}, - {file = "torch-1.7.1-cp38-cp38-win_amd64.whl", hash = "sha256:e000b94be3aa58ad7f61e7d07cf379ea9366cf6c6874e68bd58ad0bdc537b3a7"}, - {file = "torch-1.7.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:2e49cac969976be63117004ee00d0a3e3dd4ea662ad77383f671b8992825de1a"}, - {file = "torch-1.7.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:a3793dcceb12b1e2281290cca1277c5ce86ddfd5bf044f654285a4d69057aea7"}, - {file = "torch-1.7.1-cp39-cp39-win_amd64.whl", hash = "sha256:6652a767a0572ae0feb74ad128758e507afd3b8396b6e7f147e438ba8d4c6f63"}, - {file = "torch-1.7.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:38d67f4fb189a92a977b2c0a38e4f6dd413e0bf55aa6d40004696df7e40a71ff"}, + {file = "torch-1.8.0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:78b84115fd03f4587382a38b0da98cdd1827117806c80ebf97843a64213816cc"}, + {file = "torch-1.8.0-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:86f13f324fd87870bd0d37864f4f5814dc27f9e7ed9ea222f1cc7d7dc01a8ffe"}, + {file = "torch-1.8.0-cp36-cp36m-win_amd64.whl", hash = "sha256:394a99d777e487e773e0172cb0a0bce5b411e3090d89844e8dd55618be9bc970"}, + {file = "torch-1.8.0-cp36-none-macosx_10_9_x86_64.whl", hash = "sha256:229a8dc38059ef6c7171f3f4f49c51e8a3d9644ce6c32dcddd9f1bac888a78aa"}, + {file = "torch-1.8.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:6ecdbd4494b4bf2d31a24ddfbdff32bd995389bc8662a454bd40d3e8ce202907"}, + {file = "torch-1.8.0-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:08aff0383e868f1e9882b732bbe6934defab690ad1745a03d5f1a150a4e1aeba"}, + {file = "torch-1.8.0-cp37-cp37m-win_amd64.whl", hash = "sha256:c87c7b0fd31c331968674cb73e82396a622b06a8e20425584922b767f2ffb259"}, + {file = "torch-1.8.0-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:b9d6c8c457b90b5167f3ab0bd1ff7193a06935533176bc6d41e1763d353e9740"}, + {file = "torch-1.8.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:fa1e391cca3937d5dea31f31a1a80a01bd4a8062c039448c254bbf5a58eb0787"}, + {file = "torch-1.8.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:affef9bce6eed232308dd89d55d3a37a105f35460f4705375980d27154c51e24"}, + {file = "torch-1.8.0-cp38-cp38-win_amd64.whl", hash = "sha256:287a7677df844bf2c4425698fd6d9434065211219cd7fd96000ed981c4d92288"}, + {file = "torch-1.8.0-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:1b58f70c150e066bcd7401a3bdfad661a04244817a5dac9990b5367523887d3f"}, + {file = "torch-1.8.0-cp38-none-macosx_11_1_arm64.whl", hash = "sha256:923856c2e6e53d5a747d83ff40faadd791d27cea2fd881b8d6990ea269f47572"}, + {file = "torch-1.8.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:2318fac860ae73dc6486c0de2223674d9ef6139fc75f157af2bf8dce4fca5524"}, + {file = "torch-1.8.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:05b114cb793816cd140794d5874f463972cb639f3b55d3a060f21fd066f5b629"}, + {file = "torch-1.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:7438431e03af793979cb1a9b5dd9f399b38461748e9f21f60e36149ee215d751"}, + {file = "torch-1.8.0-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:d98d167994d2e30df61a98eaca1684c50761f096d7f76c0c99789ac8cea50b55"}, ] torch-summary = [ {file = "torch-summary-1.4.3.tar.gz", hash = "sha256:2dcbc1dfd07dca9f4080bcacdaf90db3f2fc28efee348c8fba9033039b0e8c82"}, {file = "torch_summary-1.4.3-py3-none-any.whl", hash = "sha256:a0a76916bd11d054fd3863dc7c474971922badfbc13d6404f9eddd297041f094"}, ] torchvision = [ - {file = "torchvision-0.8.2-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:86fae370d222f76ad57c57c3bee03f78b8db727743bfb4c1559a3d395159cea8"}, - {file = "torchvision-0.8.2-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:951239b5fcb911dbf78c1385d677f5f48c7a1b12859e3d3ec287562821b17cf2"}, - {file = "torchvision-0.8.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:24db8f4c3d812a032273f68563ad5dbd724f5bfbed523d0c6dce8cede26bb153"}, - {file = "torchvision-0.8.2-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:b068f6bcbe91bdd34dda0a39e8a26392add45a3be82543f6dd523b76484fb56f"}, - {file = "torchvision-0.8.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:afb76a66b9b0693f758a881a2bf333ed97e3c0c3f15a413c4f49d8dd8bd21307"}, - {file = "torchvision-0.8.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:cd8817e9197fc60ebae37162a445db90bbf35591314a5767ad3d1490b5d65b0f"}, - {file = "torchvision-0.8.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1bd58acc3366ec02266aae56a7a752d43ef07de4a6ba420c4f907d0c9168bb8c"}, - {file = "torchvision-0.8.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:976750a49db2e23dc5a1ed0b5c31f7af51ed2702eee410ee09ef985c3a3e48cf"}, + {file = "torchvision-0.9.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:63052147c776d9f93385410c1d5a791386eb0cb5e1b93c7feac686f8dbe6eb06"}, + {file = "torchvision-0.9.0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:874714f30822d4c1160071dac004d48ae641bfdccccbb497098c86f6589ec0f1"}, + {file = "torchvision-0.9.0-cp36-cp36m-win_amd64.whl", hash = "sha256:b9f71f62725776495071b875494af86615f225b1a40902f5df452da5cfde0510"}, + {file = "torchvision-0.9.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:24b505dbcf3cb8da49d4b1447543c1021b699c84fc3701523101b62ee4adf097"}, + {file = "torchvision-0.9.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:d90750ae76a0cad8ffb6b509b30412dcd102d27d5f34f7184b289b6687de580e"}, + {file = "torchvision-0.9.0-cp37-cp37m-win_amd64.whl", hash = "sha256:8791da742c24344646a4ac36adee9327491f7fff7607dffe352402b5bf25ea21"}, + {file = "torchvision-0.9.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:fa302f6e8fe33a8d5c6649e659655c0427eee662fe22ce69eb56fa402b520c26"}, + {file = "torchvision-0.9.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:8fce78a59959f4bb4780a78c2277d617e44da7bc270bc449ff403187f6b587fc"}, + {file = "torchvision-0.9.0-cp38-cp38-win_amd64.whl", hash = "sha256:2252bc63fcccb27785726dd9d0d9a97432657a5d139390bf93cd6bdf227a4401"}, + {file = "torchvision-0.9.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:421bda7131f3c0eae2260f10174ac3c49e54183b33acb927b4b572f4cd90066d"}, + {file = "torchvision-0.9.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:d85d405e8cf694c1f85da7f0496ea69dd4f8d8dafbdad1e29bcdc4c621fc5cf0"}, + {file = "torchvision-0.9.0-cp39-cp39-win_amd64.whl", hash = "sha256:b03275f351feffaf7450d234ffb57cce26ff5e696d01ef5f543de205f18849a9"}, ] tornado = [ {file = "tornado-6.1-cp35-cp35m-macosx_10_9_x86_64.whl", hash = "sha256:d371e811d6b156d82aa5f9a4e08b58debf97c302a35714f6f45e35139c332e32"}, diff --git a/pyproject.toml b/pyproject.toml index ef75edf..33d539e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,8 +22,8 @@ sphinx_rtd_theme = "^0.4.3" boltons = "^20.1.0" h5py = "^3.2.1" toml = "^0.10.1" -torch = "^1.7.0" -torchvision = "^0.8.1" +torch = "^1.8.0" +torchvision = "^0.9.0" loguru = "^0.5.0" matplotlib = "^3.2.1" tqdm = "^4.46.1" @@ -64,12 +64,13 @@ gpustat = "^0.6.0" redlock-py = "^1.0.8" wandb = "^0.10.11" graphviz = "^0.16" +scipy = "^1.6.1" [tool.coverage.report] fail_under = 50 [tool.poetry.scripts] -download-emnist = "text_recognizer.datasets.util:download_emnist" +download-emnist = "text_recognizer.datasets.emnist:download_emnist" download-iam = "text_recognizer.datasets.iam_dataset:main" create-emnist-support-files = "text_recognizer.tests.support.create_emnist_support_files:create_emnist_support_files" create-emnist-lines-datasets = "text_recognizer.datasets.emnist_lines_dataset:create_datasets" @@ -78,7 +79,6 @@ prepare-experiments = "training.prepare_experiments:run_cli" run-experiment = "training.run_experiment:run_cli" - [build-system] requires = ["poetry>=0.12"] build-backend = "poetry.masonry.api" diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..18ff212 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Test modules for the text text recognizer.""" diff --git a/tests/support/__init__.py b/tests/support/__init__.py new file mode 100644 index 0000000..a265ede --- /dev/null +++ b/tests/support/__init__.py @@ -0,0 +1,2 @@ +"""Support file modules.""" +from .create_emnist_support_files import create_emnist_support_files diff --git a/tests/support/create_emnist_lines_support_files.py b/tests/support/create_emnist_lines_support_files.py new file mode 100644 index 0000000..9abe143 --- /dev/null +++ b/tests/support/create_emnist_lines_support_files.py @@ -0,0 +1,51 @@ +"""Module for creating EMNIST Lines test support files.""" +# flake8: noqa: S106 + +from pathlib import Path +import shutil + +import numpy as np + +from text_recognizer.datasets import EmnistLinesDataset +import text_recognizer.util as util + + +SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "emnist_lines" + + +def create_emnist_lines_support_files() -> None: + """Create EMNIST Lines test images.""" + shutil.rmtree(SUPPORT_DIRNAME, ignore_errors=True) + SUPPORT_DIRNAME.mkdir() + + # TODO: maybe have to add args to dataset. + dataset = EmnistLinesDataset( + init_token="", + pad_token="_", + eos_token="", + transform=[{"type": "ToTensor", "args": {}}], + target_transform=[ + { + "type": "AddTokens", + "args": {"init_token": "", "pad_token": "_", "eos_token": ""}, + } + ], + ) # nosec: S106 + dataset.load_or_generate_data() + + for index in [5, 7, 9]: + image, target = dataset[index] + if len(image.shape) == 3: + image = image.squeeze(0) + print(image.sum(), image.dtype) + + label = "".join(dataset.mapper(label) for label in target[1:]).strip( + dataset.mapper.pad_token + ) + print(label) + image = image.numpy() + util.write_image(image, str(SUPPORT_DIRNAME / f"{label}.png")) + + +if __name__ == "__main__": + create_emnist_lines_support_files() diff --git a/tests/support/create_emnist_support_files.py b/tests/support/create_emnist_support_files.py new file mode 100644 index 0000000..f9ff030 --- /dev/null +++ b/tests/support/create_emnist_support_files.py @@ -0,0 +1,30 @@ +"""Module for creating EMNIST test support files.""" +from pathlib import Path +import shutil + +from text_recognizer.datasets import EmnistDataset +from text_recognizer.util import write_image + +SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "emnist" + + +def create_emnist_support_files() -> None: + """Create support images for test of CharacterPredictor class.""" + shutil.rmtree(SUPPORT_DIRNAME, ignore_errors=True) + SUPPORT_DIRNAME.mkdir() + + dataset = EmnistDataset(train=False) + dataset.load_or_generate_data() + + for index in [5, 7, 9]: + image, label = dataset[index] + if len(image.shape) == 3: + image = image.squeeze(0) + image = image.numpy() + label = dataset.mapper(int(label)) + print(index, label) + write_image(image, str(SUPPORT_DIRNAME / f"{label}.png")) + + +if __name__ == "__main__": + create_emnist_support_files() diff --git a/tests/support/create_iam_lines_support_files.py b/tests/support/create_iam_lines_support_files.py new file mode 100644 index 0000000..50f9e3d --- /dev/null +++ b/tests/support/create_iam_lines_support_files.py @@ -0,0 +1,50 @@ +"""Module for creating IAM Lines test support files.""" +# flake8: noqa +from pathlib import Path +import shutil + +import numpy as np + +from text_recognizer.datasets import IamLinesDataset +import text_recognizer.util as util + + +SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "iam_lines" + + +def create_emnist_lines_support_files() -> None: + """Create IAM Lines test images.""" + shutil.rmtree(SUPPORT_DIRNAME, ignore_errors=True) + SUPPORT_DIRNAME.mkdir() + + # TODO: maybe have to add args to dataset. + dataset = IamLinesDataset( + init_token="", + pad_token="_", + eos_token="", + transform=[{"type": "ToTensor", "args": {}}], + target_transform=[ + { + "type": "AddTokens", + "args": {"init_token": "", "pad_token": "_", "eos_token": ""}, + } + ], + ) + dataset.load_or_generate_data() + + for index in [0, 1, 3]: + image, target = dataset[index] + if len(image.shape) == 3: + image = image.squeeze(0) + print(image.sum(), image.dtype) + + label = "".join(dataset.mapper(label) for label in target[1:]).strip( + dataset.mapper.pad_token + ) + print(label) + image = image.numpy() + util.write_image(image, str(SUPPORT_DIRNAME / f"{label}.png")) + + +if __name__ == "__main__": + create_emnist_lines_support_files() diff --git a/tests/support/emnist_lines/Knox Ky.png b/tests/support/emnist_lines/Knox Ky.png new file mode 100644 index 0000000..b7d0618 Binary files /dev/null and b/tests/support/emnist_lines/Knox Ky.png differ diff --git a/tests/support/emnist_lines/ancillary beliefs and.png b/tests/support/emnist_lines/ancillary beliefs and.png new file mode 100644 index 0000000..14a8cf3 Binary files /dev/null and b/tests/support/emnist_lines/ancillary beliefs and.png differ diff --git a/tests/support/emnist_lines/they.png b/tests/support/emnist_lines/they.png new file mode 100644 index 0000000..7f05951 Binary files /dev/null and b/tests/support/emnist_lines/they.png differ diff --git a/tests/support/iam_lines/He rose from his breakfast-nook bench.png b/tests/support/iam_lines/He rose from his breakfast-nook bench.png new file mode 100644 index 0000000..6eeb642 Binary files /dev/null and b/tests/support/iam_lines/He rose from his breakfast-nook bench.png differ diff --git a/tests/support/iam_lines/and came into the livingroom, where.png b/tests/support/iam_lines/and came into the livingroom, where.png new file mode 100644 index 0000000..4974cf8 Binary files /dev/null and b/tests/support/iam_lines/and came into the livingroom, where.png differ diff --git a/tests/support/iam_lines/his entrance. He came, almost falling.png b/tests/support/iam_lines/his entrance. He came, almost falling.png new file mode 100644 index 0000000..a731245 Binary files /dev/null and b/tests/support/iam_lines/his entrance. He came, almost falling.png differ diff --git a/tests/support/iam_paragraphs/a01-000u.jpg b/tests/support/iam_paragraphs/a01-000u.jpg new file mode 100644 index 0000000..d9753b6 Binary files /dev/null and b/tests/support/iam_paragraphs/a01-000u.jpg differ diff --git a/tests/test_character_predictor.py b/tests/test_character_predictor.py new file mode 100644 index 0000000..01bda78 --- /dev/null +++ b/tests/test_character_predictor.py @@ -0,0 +1,31 @@ +"""Test for CharacterPredictor class.""" +import importlib +import os +from pathlib import Path +import unittest + +from loguru import logger + +from text_recognizer.character_predictor import CharacterPredictor +from text_recognizer.networks import MLP + +SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "support" / "emnist" + +os.environ["CUDA_VISIBLE_DEVICES"] = "" + + +class TestCharacterPredictor(unittest.TestCase): + """Tests for the CharacterPredictor class.""" + + def test_filename(self) -> None: + """Test that CharacterPredictor correctly predicts on a single image, for serveral test images.""" + network_fn_ = MLP + predictor = CharacterPredictor(network_fn=network_fn_) + + for filename in SUPPORT_DIRNAME.glob("*.png"): + pred, conf = predictor.predict(str(filename)) + logger.info( + f"Prediction: {pred} at confidence: {conf} for image with character {filename.stem}" + ) + self.assertEqual(pred, filename.stem) + self.assertGreater(conf, 0.7) diff --git a/tests/test_line_predictor.py b/tests/test_line_predictor.py new file mode 100644 index 0000000..eede4d4 --- /dev/null +++ b/tests/test_line_predictor.py @@ -0,0 +1,35 @@ +"""Tests for LinePredictor.""" +import os +from pathlib import Path +import unittest + + +import editdistance +import numpy as np + +from text_recognizer.datasets import IamLinesDataset +from text_recognizer.line_predictor import LinePredictor +import text_recognizer.util as util + +SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "support" + +os.environ["CUDA_VISIBLE_DEVICES"] = "" + + +class TestEmnistLinePredictor(unittest.TestCase): + """Test LinePredictor class on the EmnistLines dataset.""" + + def test_filename(self) -> None: + """Test that LinePredictor correctly predicts on single images, for several test images.""" + predictor = LinePredictor( + dataset="EmnistLineDataset", network_fn="CNNTransformer" + ) + + for filename in (SUPPORT_DIRNAME / "emnist_lines").glob("*.png"): + pred, conf = predictor.predict(str(filename)) + true = str(filename.stem) + edit_distance = editdistance.eval(pred, true) / len(pred) + print( + f'Pred: "{pred}" | Confidence: {conf} | True: {true} | Edit distance: {edit_distance}' + ) + self.assertLess(edit_distance, 0.2) diff --git a/tests/test_paragraph_text_recognizer.py b/tests/test_paragraph_text_recognizer.py new file mode 100644 index 0000000..3e280b9 --- /dev/null +++ b/tests/test_paragraph_text_recognizer.py @@ -0,0 +1,37 @@ +"""Test for ParagraphTextRecognizer class.""" +import os +from pathlib import Path +import unittest + +from text_recognizer.paragraph_text_recognizer import ParagraphTextRecognizor +import text_recognizer.util as util + + +SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "support" / "iam_paragraph" + +# Prevent using GPU. +os.environ["CUDA_VISIBLE_DEVICES"] = "" + + +class TestParagraphTextRecognizor(unittest.TestCase): + """Test that it can take non-square images of max dimension larger than 256px.""" + + def test_filename(self) -> None: + """Test model on support image.""" + line_predictor_args = { + "dataset": "EmnistLineDataset", + "network_fn": "CNNTransformer", + } + line_detector_args = {"dataset": "EmnistLineDataset", "network_fn": "UNet"} + model = ParagraphTextRecognizor( + line_predictor_args=line_predictor_args, + line_detector_args=line_detector_args, + ) + num_text_lines_by_name = {"a01-000u-cropped": 7} + for filename in (SUPPORT_DIRNAME).glob("*.jpg"): + full_image = util.read_image(str(filename), grayscale=True) + predicted_text, line_region_crops = model.predict(full_image) + print(predicted_text) + self.assertTrue( + len(line_region_crops), num_text_lines_by_name[filename.stem] + ) diff --git a/text_recognizer/datasets/__init__.py b/text_recognizer/datasets/__init__.py index a6c1c59..2727b20 100644 --- a/text_recognizer/datasets/__init__.py +++ b/text_recognizer/datasets/__init__.py @@ -1,39 +1 @@ """Dataset modules.""" -from .emnist_dataset import EmnistDataset -from .emnist_lines_dataset import ( - construct_image_from_string, - EmnistLinesDataset, - get_samples_by_character, -) -from .iam_dataset import IamDataset -from .iam_lines_dataset import IamLinesDataset -from .iam_paragraphs_dataset import IamParagraphsDataset -from .iam_preprocessor import load_metadata, Preprocessor -from .transforms import AddTokens, Transpose -from .util import ( - _download_raw_dataset, - compute_sha256, - DATA_DIRNAME, - download_url, - EmnistMapper, - ESSENTIALS_FILENAME, -) - -__all__ = [ - "_download_raw_dataset", - "AddTokens", - "compute_sha256", - "construct_image_from_string", - "DATA_DIRNAME", - "download_url", - "EmnistDataset", - "EmnistMapper", - "EmnistLinesDataset", - "get_samples_by_character", - "load_metadata", - "IamDataset", - "IamLinesDataset", - "IamParagraphsDataset", - "Preprocessor", - "Transpose", -] diff --git a/text_recognizer/datasets/base_data_module.py b/text_recognizer/datasets/base_data_module.py index 09a0a43..830b39b 100644 --- a/text_recognizer/datasets/base_data_module.py +++ b/text_recognizer/datasets/base_data_module.py @@ -16,7 +16,7 @@ def load_and_print_info(data_module_class: type) -> None: class BaseDataModule(pl.LightningDataModule): """Base PyTorch Lightning DataModule.""" - + def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None: super().__init__() self.batch_size = batch_size @@ -34,13 +34,17 @@ class BaseDataModule(pl.LightningDataModule): def config(self) -> Dict: """Return important settings of the dataset.""" - return {"input_dim": self.dims, "output_dims": self.output_dims, "mapping": self.mapping} + return { + "input_dim": self.dims, + "output_dims": self.output_dims, + "mapping": self.mapping, + } def prepare_data(self) -> None: """Prepare data for training.""" pass - def setup(self, stage: Any = None) -> None: + def setup(self, stage: str = None) -> None: """Split into train, val, test, and set dims. Should assign `torch Dataset` objects to self.data_train, self.data_val, and @@ -54,16 +58,32 @@ class BaseDataModule(pl.LightningDataModule): self.data_val = None self.data_test = None - def train_dataloader(self) -> DataLoader: """Retun DataLoader for train data.""" - return DataLoader(self.data_train, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True) + return DataLoader( + self.data_train, + shuffle=True, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=True, + ) def val_dataloader(self) -> DataLoader: """Return DataLoader for val data.""" - return DataLoader(self.data_val, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True) + return DataLoader( + self.data_val, + shuffle=False, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=True, + ) def test_dataloader(self) -> DataLoader: """Return DataLoader for val data.""" - return DataLoader(self.data_test, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True) - + return DataLoader( + self.data_test, + shuffle=False, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=True, + ) diff --git a/text_recognizer/datasets/base_dataset.py b/text_recognizer/datasets/base_dataset.py index 7322d7f..a004b8d 100644 --- a/text_recognizer/datasets/base_dataset.py +++ b/text_recognizer/datasets/base_dataset.py @@ -17,12 +17,14 @@ class BaseDataset(Dataset): target_transform (Callable): Fucntion that takes a target and applies target transforms. """ - def __init__(self, - data: Union[Sequence, Tensor], - targets: Union[Sequence, Tensor], - transform: Callable = None, - target_transform: Callable = None, - ) -> None: + + def __init__( + self, + data: Union[Sequence, Tensor], + targets: Union[Sequence, Tensor], + transform: Callable = None, + target_transform: Callable = None, + ) -> None: if len(data) != len(targets): raise ValueError("Data and targets must be of equal length.") self.data = data @@ -30,11 +32,10 @@ class BaseDataset(Dataset): self.transform = transform self.target_transform = target_transform - def __len__(self) -> int: """Return the length of the dataset.""" return len(self.data) - + def __getitem__(self, index: int) -> Tuple[Any, Any]: """Return a datum and its target, after processing by transforms. @@ -56,7 +57,9 @@ class BaseDataset(Dataset): return datum, target -def convert_strings_to_labels(strings: Sequence[str], mapping: Dict[str, int], length: int) -> Tensor: +def convert_strings_to_labels( + strings: Sequence[str], mapping: Dict[str, int], length: int +) -> Tensor: """ Convert a sequence of N strings to (N, length) ndarray, with each string wrapped with and tokens, and padded wiht the

token. diff --git a/text_recognizer/datasets/dataset.py b/text_recognizer/datasets/dataset.py deleted file mode 100644 index e794605..0000000 --- a/text_recognizer/datasets/dataset.py +++ /dev/null @@ -1,152 +0,0 @@ -"""Abstract dataset class.""" -from typing import Callable, Dict, List, Optional, Tuple, Union - -import torch -from torch import Tensor -from torch.utils import data -from torchvision.transforms import ToTensor - -import text_recognizer.datasets.transforms as transforms -from text_recognizer.datasets.util import EmnistMapper - - -class Dataset(data.Dataset): - """Abstract class for with common methods for all datasets.""" - - def __init__( - self, - train: bool, - subsample_fraction: float = None, - transform: Optional[List[Dict]] = None, - target_transform: Optional[List[Dict]] = None, - init_token: Optional[str] = None, - pad_token: Optional[str] = None, - eos_token: Optional[str] = None, - lower: bool = False, - ) -> None: - """Initialization of Dataset class. - - Args: - train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False. - subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None. - transform (Optional[List[Dict]]): List of Transform types and args for input data. Defaults to None. - target_transform (Optional[List[Dict]]): List of Transform types and args for output data. Defaults to None. - init_token (Optional[str]): String representing the start of sequence token. Defaults to None. - pad_token (Optional[str]): String representing the pad token. Defaults to None. - eos_token (Optional[str]): String representing the end of sequence token. Defaults to None. - lower (bool): Only use lower case letters. Defaults to False. - - Raises: - ValueError: If subsample_fraction is not None and outside the range (0, 1). - - """ - self.train = train - self.split = "train" if self.train else "test" - - if subsample_fraction is not None: - if not 0.0 < subsample_fraction < 1.0: - raise ValueError("The subsample fraction must be in (0, 1).") - self.subsample_fraction = subsample_fraction - - self._mapper = EmnistMapper( - init_token=init_token, eos_token=eos_token, pad_token=pad_token, lower=lower - ) - self._input_shape = self._mapper.input_shape - self._output_shape = self._mapper._num_classes - self.num_classes = self.mapper.num_classes - - # Set transforms. - self.transform = self._configure_transform(transform) - self.target_transform = self._configure_target_transform(target_transform) - - self._data = None - self._targets = None - - def _configure_transform(self, transform: List[Dict]) -> transforms.Compose: - transform_list = [] - if transform is not None: - for t in transform: - t_type = t["type"] - t_args = t["args"] or {} - transform_list.append(getattr(transforms, t_type)(**t_args)) - else: - transform_list.append(ToTensor()) - return transforms.Compose(transform_list) - - def _configure_target_transform( - self, target_transform: List[Dict] - ) -> transforms.Compose: - target_transform_list = [torch.tensor] - if target_transform is not None: - for t in target_transform: - t_type = t["type"] - t_args = t["args"] or {} - target_transform_list.append(getattr(transforms, t_type)(**t_args)) - return transforms.Compose(target_transform_list) - - @property - def data(self) -> Tensor: - """The input data.""" - return self._data - - @property - def targets(self) -> Tensor: - """The target data.""" - return self._targets - - @property - def input_shape(self) -> Tuple: - """Input shape of the data.""" - return self._input_shape - - @property - def output_shape(self) -> Tuple: - """Output shape of the data.""" - return self._output_shape - - @property - def mapper(self) -> EmnistMapper: - """Returns the EmnistMapper.""" - return self._mapper - - @property - def mapping(self) -> Dict: - """Return EMNIST mapping from index to character.""" - return self._mapper.mapping - - @property - def inverse_mapping(self) -> Dict: - """Returns the inverse mapping from character to index.""" - return self.mapper.inverse_mapping - - def _subsample(self) -> None: - """Only this fraction of the data will be loaded.""" - if self.subsample_fraction is None: - return - num_subsample = int(self.data.shape[0] * self.subsample_fraction) - self._data = self.data[:num_subsample] - self._targets = self.targets[:num_subsample] - - def __len__(self) -> int: - """Returns the length of the dataset.""" - return len(self.data) - - def load_or_generate_data(self) -> None: - """Load or generate dataset data.""" - raise NotImplementedError - - def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]: - """Fetches samples from the dataset. - - Args: - index (Union[int, torch.Tensor]): The indices of the samples to fetch. - - Raises: - NotImplementedError: If the method is not implemented in child class. - - """ - raise NotImplementedError - - def __repr__(self) -> str: - """Returns information about the dataset.""" - raise NotImplementedError diff --git a/text_recognizer/datasets/download_utils.py b/text_recognizer/datasets/download_utils.py index 7a2cab8..e3dc68c 100644 --- a/text_recognizer/datasets/download_utils.py +++ b/text_recognizer/datasets/download_utils.py @@ -63,11 +63,11 @@ def download_dataset(metadata: Dict, dl_dir: Path) -> Optional[Path]: if filename.exists(): return logger.info(f"Downloading raw dataset from {metadata['url']} to {filename}...") - _download_url(metadata["url"], filename) + _download_url(metadata["url"], filename) logger.info("Computing the SHA-256...") sha256 = _compute_sha256(filename) if sha256 != metadata["sha256"]: raise ValueError( - "Downloaded data file SHA-256 does not match that listed in metadata document." - ) + "Downloaded data file SHA-256 does not match that listed in metadata document." + ) return filename diff --git a/text_recognizer/datasets/emnist.py b/text_recognizer/datasets/emnist.py index e99dbfd..7c208c4 100644 --- a/text_recognizer/datasets/emnist.py +++ b/text_recognizer/datasets/emnist.py @@ -15,20 +15,23 @@ from torch.utils.data import random_split from torchvision import transforms from text_recognizer.datasets.base_dataset import BaseDataset -from text_recognizer.datasets.base_data_module import BaseDataModule, load_print_info +from text_recognizer.datasets.base_data_module import ( + BaseDataModule, + load_and_print_info, +) from text_recognizer.datasets.download_utils import download_dataset SEED = 4711 NUM_SPECIAL_TOKENS = 4 -SAMPLE_TO_BALANCE = True +SAMPLE_TO_BALANCE = True RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "emnist" METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "emnist" -PROCESSED_DATA_DIRNAME = BaseDataset.data_dirname() / "processed" / "emnist" +PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist" PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5" -ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnsit_essentials.json" +ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json" class EMNIST(BaseDataModule): @@ -41,7 +44,9 @@ class EMNIST(BaseDataModule): EMNIST ByClass: 814,255 characters. 62 unbalanced classes. """ - def __init__(self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8) -> None: + def __init__( + self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8 + ) -> None: super().__init__(batch_size, num_workers) if not ESSENTIALS_FILENAME.exists(): _download_and_process_emnist() @@ -64,20 +69,21 @@ class EMNIST(BaseDataModule): def setup(self, stage: str = None) -> None: if stage == "fit" or stage is None: with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: - data = f["x_train"][:] - targets = f["y_train"][:] - - dataset_train = BaseDataset(data, targets, transform=self.transform) + self.x_train = f["x_train"][:] + self.y_train = f["y_train"][:] + + dataset_train = BaseDataset(self.x_train, self.y_train, transform=self.transform) train_size = int(self.train_fraction * len(dataset_train)) val_size = len(dataset_train) - train_size - self.data_train, self.data_val = random_split(dataset_train, [train_size, val_size], generator=torch.Generator()) + self.data_train, self.data_val = random_split( + dataset_train, [train_size, val_size], generator=torch.Generator() + ) if stage == "test" or stage is None: with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: - data = f["x_test"][:] - targets = f["y_test"][:] - self.data_test = BaseDataset(data, targets, transform=self.transform) - + self.x_test = f["x_test"][:] + self.y_test = f["y_test"][:] + self.data_test = BaseDataset(self.x_test, self.y_test, transform=self.transform) def __repr__(self) -> str: basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.dims}\n" @@ -111,9 +117,15 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None: logger.info("Loading training data from .mat file") data = loadmat("matlab/emnist-byclass.mat") - x_train = data["dataset"]["train"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) + x_train = ( + data["dataset"]["train"][0, 0]["images"][0, 0] + .reshape(-1, 28, 28) + .swapaxes(1, 2) + ) y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS - x_test = data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) + x_test = ( + data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) + ) y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS if SAMPLE_TO_BALANCE: @@ -121,7 +133,6 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None: x_train, y_train = _sample_to_balance(x_train, y_train) x_test, y_test = _sample_to_balance(x_test, y_test) - logger.info("Saving to HDF5 in a compressed format...") PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) with h5py.File(PROCESSED_DATA_FILENAME, "w") as f: @@ -154,7 +165,7 @@ def _sample_to_balance(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.nda all_sampled_indices.append(sampled_indices) indices = np.concatenate(all_sampled_indices) x_sampled = x[indices] - y_sampled= y[indices] + y_sampled = y[indices] return x_sampled, y_sampled @@ -162,24 +173,24 @@ def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]: """Augment the mapping with extra symbols.""" # Extra characters from the IAM dataset. iam_characters = [ - " ", - "!", - '"', - "#", - "&", - "'", - "(", - ")", - "*", - "+", - ",", - "-", - ".", - "/", - ":", - ";", - "?", - ] + " ", + "!", + '"', + "#", + "&", + "'", + "(", + ")", + "*", + "+", + ",", + "-", + ".", + "/", + ":", + ";", + "?", + ] # Also add special tokens for: # - CTC blank token at index 0 @@ -190,5 +201,6 @@ def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]: return ["", "", "", "

", *characters, *iam_characters] -if __name__ == "__main__": - load_print_info(EMNIST) +def download_emnist() -> None: + """Download dataset from internet, if it does not exists, and displays info.""" + load_and_print_info(EMNIST) diff --git a/text_recognizer/datasets/emnist_essentials.json b/text_recognizer/datasets/emnist_essentials.json new file mode 100644 index 0000000..100b36a --- /dev/null +++ b/text_recognizer/datasets/emnist_essentials.json @@ -0,0 +1 @@ +{"characters": ["", "", "", "

", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]} \ No newline at end of file diff --git a/text_recognizer/datasets/emnist_lines.py b/text_recognizer/datasets/emnist_lines.py new file mode 100644 index 0000000..ae23feb --- /dev/null +++ b/text_recognizer/datasets/emnist_lines.py @@ -0,0 +1,184 @@ +"""Dataset of generated text from EMNIST characters.""" +from collections import defaultdict +from pathlib import Path +from typing import Dict, Sequence + +import h5py +from loguru import logger +import numpy as np +import torch +from torchvision import transforms + +from text_recognizer.datasets.base_dataset import BaseDataset +from text_recognizer.datasets.base_data_module import BaseDataModule +from text_recognizer.datasets.emnist import EMNIST +from text_recognizer.datasets.sentence_generator import SentenceGenerator + + +DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist_lines" +ESSENTIALS_FILENAME = ( + Path(__file__).parents[0].resolve() / "emnist_lines_essentials.json" +) + +SEED = 4711 +IMAGE_HEIGHT = 56 +IMAGE_WIDTH = 1024 +IMAGE_X_PADDING = 28 +MAX_OUTPUT_LENGTH = 89 # Same as IAMLines + + +class EMNISTLines(BaseDataModule): + """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST,""" + + def __init__( + self, + augment: bool = True, + batch_size: int = 128, + num_workers: int = 0, + max_length: int = 32, + min_overlap: float = 0.0, + max_overlap: float = 0.33, + num_train: int = 10_000, + num_val: int = 2_000, + num_test: int = 2_000, + ) -> None: + super().__init__(batch_size, num_workers) + + self.augment = augment + self.max_length = max_length + self.min_overlap = min_overlap + self.max_overlap = max_overlap + self.num_train = num_train + self.num_val = num_val + self.num_test = num_test + + self.emnist = EMNIST() + self.mapping = self.emnist.mapping + max_width = int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap)) + IMAGE_X_PADDING + + if max_width <= IMAGE_WIDTH: + raise ValueError("max_width greater than IMAGE_WIDTH") + + self.dims = ( + self.emnist.dims[0], + self.emnist.dims[1], + self.emnist.dims[2] * self.max_length, + ) + + if self.max_length <= MAX_OUTPUT_LENGTH: + raise ValueError("max_length greater than MAX_OUTPUT_LENGTH") + + self.output_dims = (MAX_OUTPUT_LENGTH, 1) + self.data_train = None + self.data_val = None + self.data_test = None + + @property + def data_filename(self) -> Path: + """Return name of dataset.""" + return ( + DATA_DIRNAME + / f"ml_{self.max_length}_o{self.min_overlap:f}_{self.max_overlap:f}_ntr{self.num_train}_ntv{self.num_val}_nte{self.num_test}_{self.with_start_end_tokens}.h5" + ) + + def prepare_data(self) -> None: + if self.data_filename.exists(): + return + np.random.seed(SEED) + self._generate_data("train") + self._generate_data("val") + self._generate_data("test") + + def setup(self, stage: str = None) -> None: + logger.info("EMNISTLinesDataset loading data from HDF5...") + if stage == "fit" or stage is None: + with h5py.File(self.data_filename, "r") as f: + x_train = f["x_train"][:] + y_train = torch.LongTensor(f["y_train"][:]) + x_val = f["x_val"][:] + y_val = torch.LongTensor(f["y_val"][:]) + + self.data_train = BaseDataset(x_train, y_train, transform=_get_transform(augment=self.augment)) + self.data_val = BaseDataset(x_val, y_val, transform=_get_transform(augment=self.augment)) + + if stage == "test" or stage is None: + with h5py.File(self.data_filename, "r") as f: + x_test = f["x_test"][:] + y_test = torch.LongTensor(f["y_test"][:]) + + self.data_train = BaseDataset(x_test, y_test, transform=_get_transform(augment=False)) + + def __repr__(self) -> str: + """Return str about dataset.""" + basic = ( + "EMNISTLines2 Dataset\n" # pylint: disable=no-member + f"Min overlap: {self.min_overlap}\n" + f"Max overlap: {self.max_overlap}\n" + f"Num classes: {len(self.mapping)}\n" + f"Dims: {self.dims}\n" + f"Output dims: {self.output_dims}\n" + ) + + if not any([self.data_train, self.data_val, self.data_test]): + return basic + + x, y = next(iter(self.train_dataloader())) + data = ( + f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" + f"Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" + f"Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" + ) + return basic + data + + def _generate_data(self, split: str) -> None: + logger.info(f"EMNISTLines generating data for {split}...") + sentence_generator = SentenceGenerator(self.max_length - 2) # Subtract by 2 because start/end token + + emnist = self.emnist + emnist.prepare_data() + emnist.setup() + + if split == "train": + samples_by_char = _get_samples_by_char(emnist.x_train, emnist.y_train, emnist.mapping) + num = self.num_train + elif split == "val": + samples_by_char = _get_samples_by_char(emnist.x_train, emnist.y_train, emnist.mapping) + num = self.num_val + elif split == "test": + samples_by_char = _get_samples_by_char(emnist.x_test, emnist.y_test, emnist.mapping) + num = self.num_test + + DATA_DIRNAME.mkdir(parents=True, exist_ok=True) + with h5py.File(self.data_filename, "w") as f: + x, y = _create_dataset_of_images( + num, samples_by_char, sentence_generator, self.min_overlap, self.max_overlap, self.dims + ) + y = _convert_strings_to_labels( + y, + emnist.inverse_mapping, + length=MAX_OUTPUT_LENGTH + ) + f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf") + f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf") + +def _get_samples_by_char(samples: np.ndarray, labels: np.ndarray, mapping: Dict) -> defaultdict: + samples_by_char = defaultdict(list) + for sample, label in zip(samples, labels): + samples_by_char[mapping[label]].append(sample) + return samples_by_char + + +def _construct_image_from_string(): + pass + + +def _select_letter_samples_for_string(string: str, samples_by_char: defaultdict): + pass + + +def _create_dataset_of_images(num_samples: int, samples_by_char: defaultdict, sentence_generator: SentenceGenerator, min_overlap: float, max_overlap: float, dims: Tuple) -> Tuple[torch.Tensor, torch.Tensor]: + images = torch.zeros((num_samples, IMAGE_HEIGHT, dims[2])) + labels = [] + for n in range(num_samples): + label = sentence_generator.generate() + crop = _construct_image_from_string() diff --git a/text_recognizer/tests/__init__.py b/text_recognizer/tests/__init__.py deleted file mode 100644 index 18ff212..0000000 --- a/text_recognizer/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Test modules for the text text recognizer.""" diff --git a/text_recognizer/tests/support/__init__.py b/text_recognizer/tests/support/__init__.py deleted file mode 100644 index a265ede..0000000 --- a/text_recognizer/tests/support/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""Support file modules.""" -from .create_emnist_support_files import create_emnist_support_files diff --git a/text_recognizer/tests/support/create_emnist_lines_support_files.py b/text_recognizer/tests/support/create_emnist_lines_support_files.py deleted file mode 100644 index 9abe143..0000000 --- a/text_recognizer/tests/support/create_emnist_lines_support_files.py +++ /dev/null @@ -1,51 +0,0 @@ -"""Module for creating EMNIST Lines test support files.""" -# flake8: noqa: S106 - -from pathlib import Path -import shutil - -import numpy as np - -from text_recognizer.datasets import EmnistLinesDataset -import text_recognizer.util as util - - -SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "emnist_lines" - - -def create_emnist_lines_support_files() -> None: - """Create EMNIST Lines test images.""" - shutil.rmtree(SUPPORT_DIRNAME, ignore_errors=True) - SUPPORT_DIRNAME.mkdir() - - # TODO: maybe have to add args to dataset. - dataset = EmnistLinesDataset( - init_token="", - pad_token="_", - eos_token="", - transform=[{"type": "ToTensor", "args": {}}], - target_transform=[ - { - "type": "AddTokens", - "args": {"init_token": "", "pad_token": "_", "eos_token": ""}, - } - ], - ) # nosec: S106 - dataset.load_or_generate_data() - - for index in [5, 7, 9]: - image, target = dataset[index] - if len(image.shape) == 3: - image = image.squeeze(0) - print(image.sum(), image.dtype) - - label = "".join(dataset.mapper(label) for label in target[1:]).strip( - dataset.mapper.pad_token - ) - print(label) - image = image.numpy() - util.write_image(image, str(SUPPORT_DIRNAME / f"{label}.png")) - - -if __name__ == "__main__": - create_emnist_lines_support_files() diff --git a/text_recognizer/tests/support/create_emnist_support_files.py b/text_recognizer/tests/support/create_emnist_support_files.py deleted file mode 100644 index f9ff030..0000000 --- a/text_recognizer/tests/support/create_emnist_support_files.py +++ /dev/null @@ -1,30 +0,0 @@ -"""Module for creating EMNIST test support files.""" -from pathlib import Path -import shutil - -from text_recognizer.datasets import EmnistDataset -from text_recognizer.util import write_image - -SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "emnist" - - -def create_emnist_support_files() -> None: - """Create support images for test of CharacterPredictor class.""" - shutil.rmtree(SUPPORT_DIRNAME, ignore_errors=True) - SUPPORT_DIRNAME.mkdir() - - dataset = EmnistDataset(train=False) - dataset.load_or_generate_data() - - for index in [5, 7, 9]: - image, label = dataset[index] - if len(image.shape) == 3: - image = image.squeeze(0) - image = image.numpy() - label = dataset.mapper(int(label)) - print(index, label) - write_image(image, str(SUPPORT_DIRNAME / f"{label}.png")) - - -if __name__ == "__main__": - create_emnist_support_files() diff --git a/text_recognizer/tests/support/create_iam_lines_support_files.py b/text_recognizer/tests/support/create_iam_lines_support_files.py deleted file mode 100644 index 50f9e3d..0000000 --- a/text_recognizer/tests/support/create_iam_lines_support_files.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Module for creating IAM Lines test support files.""" -# flake8: noqa -from pathlib import Path -import shutil - -import numpy as np - -from text_recognizer.datasets import IamLinesDataset -import text_recognizer.util as util - - -SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "iam_lines" - - -def create_emnist_lines_support_files() -> None: - """Create IAM Lines test images.""" - shutil.rmtree(SUPPORT_DIRNAME, ignore_errors=True) - SUPPORT_DIRNAME.mkdir() - - # TODO: maybe have to add args to dataset. - dataset = IamLinesDataset( - init_token="", - pad_token="_", - eos_token="", - transform=[{"type": "ToTensor", "args": {}}], - target_transform=[ - { - "type": "AddTokens", - "args": {"init_token": "", "pad_token": "_", "eos_token": ""}, - } - ], - ) - dataset.load_or_generate_data() - - for index in [0, 1, 3]: - image, target = dataset[index] - if len(image.shape) == 3: - image = image.squeeze(0) - print(image.sum(), image.dtype) - - label = "".join(dataset.mapper(label) for label in target[1:]).strip( - dataset.mapper.pad_token - ) - print(label) - image = image.numpy() - util.write_image(image, str(SUPPORT_DIRNAME / f"{label}.png")) - - -if __name__ == "__main__": - create_emnist_lines_support_files() diff --git a/text_recognizer/tests/support/emnist_lines/Knox Ky.png b/text_recognizer/tests/support/emnist_lines/Knox Ky.png deleted file mode 100644 index b7d0618..0000000 Binary files a/text_recognizer/tests/support/emnist_lines/Knox Ky.png and /dev/null differ diff --git a/text_recognizer/tests/support/emnist_lines/ancillary beliefs and.png b/text_recognizer/tests/support/emnist_lines/ancillary beliefs and.png deleted file mode 100644 index 14a8cf3..0000000 Binary files a/text_recognizer/tests/support/emnist_lines/ancillary beliefs and.png and /dev/null differ diff --git a/text_recognizer/tests/support/emnist_lines/they.png b/text_recognizer/tests/support/emnist_lines/they.png deleted file mode 100644 index 7f05951..0000000 Binary files a/text_recognizer/tests/support/emnist_lines/they.png and /dev/null differ diff --git a/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench.png b/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench.png deleted file mode 100644 index 6eeb642..0000000 Binary files a/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench.png and /dev/null differ diff --git a/text_recognizer/tests/support/iam_lines/and came into the livingroom, where.png b/text_recognizer/tests/support/iam_lines/and came into the livingroom, where.png deleted file mode 100644 index 4974cf8..0000000 Binary files a/text_recognizer/tests/support/iam_lines/and came into the livingroom, where.png and /dev/null differ diff --git a/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling.png b/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling.png deleted file mode 100644 index a731245..0000000 Binary files a/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling.png and /dev/null differ diff --git a/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg b/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg deleted file mode 100644 index d9753b6..0000000 Binary files a/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg and /dev/null differ diff --git a/text_recognizer/tests/test_character_predictor.py b/text_recognizer/tests/test_character_predictor.py deleted file mode 100644 index 01bda78..0000000 --- a/text_recognizer/tests/test_character_predictor.py +++ /dev/null @@ -1,31 +0,0 @@ -"""Test for CharacterPredictor class.""" -import importlib -import os -from pathlib import Path -import unittest - -from loguru import logger - -from text_recognizer.character_predictor import CharacterPredictor -from text_recognizer.networks import MLP - -SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "support" / "emnist" - -os.environ["CUDA_VISIBLE_DEVICES"] = "" - - -class TestCharacterPredictor(unittest.TestCase): - """Tests for the CharacterPredictor class.""" - - def test_filename(self) -> None: - """Test that CharacterPredictor correctly predicts on a single image, for serveral test images.""" - network_fn_ = MLP - predictor = CharacterPredictor(network_fn=network_fn_) - - for filename in SUPPORT_DIRNAME.glob("*.png"): - pred, conf = predictor.predict(str(filename)) - logger.info( - f"Prediction: {pred} at confidence: {conf} for image with character {filename.stem}" - ) - self.assertEqual(pred, filename.stem) - self.assertGreater(conf, 0.7) diff --git a/text_recognizer/tests/test_line_predictor.py b/text_recognizer/tests/test_line_predictor.py deleted file mode 100644 index eede4d4..0000000 --- a/text_recognizer/tests/test_line_predictor.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Tests for LinePredictor.""" -import os -from pathlib import Path -import unittest - - -import editdistance -import numpy as np - -from text_recognizer.datasets import IamLinesDataset -from text_recognizer.line_predictor import LinePredictor -import text_recognizer.util as util - -SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "support" - -os.environ["CUDA_VISIBLE_DEVICES"] = "" - - -class TestEmnistLinePredictor(unittest.TestCase): - """Test LinePredictor class on the EmnistLines dataset.""" - - def test_filename(self) -> None: - """Test that LinePredictor correctly predicts on single images, for several test images.""" - predictor = LinePredictor( - dataset="EmnistLineDataset", network_fn="CNNTransformer" - ) - - for filename in (SUPPORT_DIRNAME / "emnist_lines").glob("*.png"): - pred, conf = predictor.predict(str(filename)) - true = str(filename.stem) - edit_distance = editdistance.eval(pred, true) / len(pred) - print( - f'Pred: "{pred}" | Confidence: {conf} | True: {true} | Edit distance: {edit_distance}' - ) - self.assertLess(edit_distance, 0.2) diff --git a/text_recognizer/tests/test_paragraph_text_recognizer.py b/text_recognizer/tests/test_paragraph_text_recognizer.py deleted file mode 100644 index 3e280b9..0000000 --- a/text_recognizer/tests/test_paragraph_text_recognizer.py +++ /dev/null @@ -1,37 +0,0 @@ -"""Test for ParagraphTextRecognizer class.""" -import os -from pathlib import Path -import unittest - -from text_recognizer.paragraph_text_recognizer import ParagraphTextRecognizor -import text_recognizer.util as util - - -SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "support" / "iam_paragraph" - -# Prevent using GPU. -os.environ["CUDA_VISIBLE_DEVICES"] = "" - - -class TestParagraphTextRecognizor(unittest.TestCase): - """Test that it can take non-square images of max dimension larger than 256px.""" - - def test_filename(self) -> None: - """Test model on support image.""" - line_predictor_args = { - "dataset": "EmnistLineDataset", - "network_fn": "CNNTransformer", - } - line_detector_args = {"dataset": "EmnistLineDataset", "network_fn": "UNet"} - model = ParagraphTextRecognizor( - line_predictor_args=line_predictor_args, - line_detector_args=line_detector_args, - ) - num_text_lines_by_name = {"a01-000u-cropped": 7} - for filename in (SUPPORT_DIRNAME).glob("*.jpg"): - full_image = util.read_image(str(filename), grayscale=True) - predicted_text, line_region_crops = model.predict(full_image) - print(predicted_text) - self.assertTrue( - len(line_region_crops), num_text_lines_by_name[filename.stem] - ) -- cgit v1.2.3-70-g09d2