summaryrefslogtreecommitdiff
path: root/src/text_recognizer/tests
diff options
context:
space:
mode:
author aktersnurra <gustaf.rydholm@gmail.com> 2020-11-15 16:21:40 +0100
committer aktersnurra <gustaf.rydholm@gmail.com> 2020-11-15 16:21:40 +0100
commit f9a222dc1dd351147ed9aec970097b9f97995403 (patch)
tree 834e458890d450d8f179551ec8ab1184b7061e54 /src/text_recognizer/tests
parent 9895780f8bb66b6fbdd16b8501775e0eb4890a64 (diff)
parent 85372e9dd47441b9eb5822bb37113e51bc9fa72d (diff)
Merge branch 'master' of github.com:aktersnurra/text-recognizer into HEAD
Diffstat (limited to 'src/text_recognizer/tests')
-rw-r--r-- src/text_recognizer/tests/support/create_emnist_lines_support_files.py 38
-rw-r--r-- src/text_recognizer/tests/support/create_emnist_support_files.py 6
-rw-r--r-- src/text_recognizer/tests/support/create_iam_lines_support_files.py 38
3 files changed, 78 insertions, 4 deletions
diff --git a/src/text_recognizer/tests/support/create_emnist_lines_support_files.py b/src/text_recognizer/tests/support/create_emnist_lines_support_files.py
new file mode 100644
index 0000000..b200ff5
--- /dev/null
+++ b/src/text_recognizer/tests/support/create_emnist_lines_support_files.py
@@ -0,0 +1,38 @@
+"""Module for creating EMNIST Lines test support files."""
+from pathlib import Path
+import shutil
+
+from text_recognizer.datasets import EmnistLinesDataset
+import text_recognizer.util as util
+
+
+SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "emnist_lines"
+
+
+def create_emnist_lines_support_files() -> None:
+    """Regenerate the EMNIST Lines sample images under ``SUPPORT_DIRNAME``.
+
+    Deletes any existing support directory, recreates it, and writes dataset
+    samples 0, 1 and 3 as ``<decoded label>.png``.
+    """
+    # Imported locally: the module header does not import numpy.
+    import numpy as np
+
+    shutil.rmtree(SUPPORT_DIRNAME, ignore_errors=True)
+    SUPPORT_DIRNAME.mkdir()
+
+    # TODO: maybe have to add args to dataset.
+    dataset = EmnistLinesDataset()
+    dataset.load_or_generate_data()
+
+    for index in [0, 1, 3]:
+        image, target = dataset[index]
+        print(image.sum(), image.dtype)
+
+        # NumPy takes ``axis``, not torch's ``dim``.
+        # NOTE(review): assumes target[1:] holds per-position class scores
+        # after a leading start token -- confirm against the dataset.
+        indices = np.argmax(np.asarray(target[1:]), axis=-1).flatten()
+
+        # The mapper is called (not subscripted), matching how the EMNIST
+        # support script uses ``dataset.mapper``.
+        label = (
+            "".join(dataset.mapper(int(token)) for token in indices)
+            .strip()
+            .strip(dataset.mapper.pad_token)
+        )
+
+        print(label)
+        util.write_image(image, str(SUPPORT_DIRNAME / f"{label}.png"))
+
+
+if __name__ == "__main__":
+ create_emnist_lines_support_files()
diff --git a/src/text_recognizer/tests/support/create_emnist_support_files.py b/src/text_recognizer/tests/support/create_emnist_support_files.py
index c04860d..f9ff030 100644
--- a/src/text_recognizer/tests/support/create_emnist_support_files.py
+++ b/src/text_recognizer/tests/support/create_emnist_support_files.py
@@ -2,8 +2,7 @@
from pathlib import Path
import shutil
-from text_recognizer.datasets.emnist_dataset import EmnistDataset
-from text_recognizer.datasets.util import EmnistMapper
+from text_recognizer.datasets import EmnistDataset
from text_recognizer.util import write_image
SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "emnist"
@@ -16,14 +15,13 @@ def create_emnist_support_files() -> None:
dataset = EmnistDataset(train=False)
dataset.load_or_generate_data()
- mapping = EmnistMapper()
for index in [5, 7, 9]:
image, label = dataset[index]
if len(image.shape) == 3:
image = image.squeeze(0)
image = image.numpy()
- label = mapping(int(label))
+ label = dataset.mapper(int(label))
print(index, label)
write_image(image, str(SUPPORT_DIRNAME / f"{label}.png"))
diff --git a/src/text_recognizer/tests/support/create_iam_lines_support_files.py b/src/text_recognizer/tests/support/create_iam_lines_support_files.py
new file mode 100644
index 0000000..15d0e4e
--- /dev/null
+++ b/src/text_recognizer/tests/support/create_iam_lines_support_files.py
@@ -0,0 +1,38 @@
+"""Module for creating IAM Lines test support files."""
+from pathlib import Path
+import shutil
+
+from text_recognizer.datasets import IamLinesDataset
+import text_recognizer.util as util
+
+
+SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "iam_lines"
+
+
+def create_emnist_lines_support_files() -> None:
+    """Regenerate the IAM Lines sample images under ``SUPPORT_DIRNAME``.
+
+    Deletes any existing support directory, recreates it, and writes dataset
+    samples 0, 1 and 3 as ``<decoded label>.png``.
+
+    NOTE(review): the function name still says "emnist_lines" although it
+    operates on ``IamLinesDataset``; it is kept because the script entry
+    point calls it by this name.
+    """
+    # Imported locally: the module header does not import numpy.
+    import numpy as np
+
+    shutil.rmtree(SUPPORT_DIRNAME, ignore_errors=True)
+    SUPPORT_DIRNAME.mkdir()
+
+    # TODO: maybe have to add args to dataset.
+    dataset = IamLinesDataset()
+    dataset.load_or_generate_data()
+
+    for index in [0, 1, 3]:
+        image, target = dataset[index]
+        print(image.sum(), image.dtype)
+
+        # NumPy takes ``axis``, not torch's ``dim``.
+        # NOTE(review): assumes target[1:] holds per-position class scores
+        # after a leading start token -- confirm against the dataset.
+        indices = np.argmax(np.asarray(target[1:]), axis=-1).flatten()
+
+        # The mapper is called (not subscripted), matching how the EMNIST
+        # support script uses ``dataset.mapper``.
+        label = (
+            "".join(dataset.mapper(int(token)) for token in indices)
+            .strip()
+            .strip(dataset.mapper.pad_token)
+        )
+
+        print(label)
+        util.write_image(image, str(SUPPORT_DIRNAME / f"{label}.png"))
+
+
+if __name__ == "__main__":
+ create_emnist_lines_support_files()