summaryrefslogtreecommitdiff
path: root/src/text_recognizer/tests
diff options
context:
space:
mode:
author aktersnurra <grydholm@kth.se> 2020-11-15 13:48:30 +0100
committer aktersnurra <grydholm@kth.se> 2020-11-15 13:48:30 +0100
commit 85372e9dd47441b9eb5822bb37113e51bc9fa72d (patch)
tree 07ed401392f3fc2174e805dd562be411781a6908 /src/text_recognizer/tests
parent 6cb08a110620ee09fe9d8a5d008197a801d025df (diff)
Added support files for lines datasets.
Diffstat (limited to 'src/text_recognizer/tests')
-rw-r--r-- src/text_recognizer/tests/support/create_emnist_lines_support_files.py 38
-rw-r--r-- src/text_recognizer/tests/support/create_emnist_support_files.py 6
-rw-r--r-- src/text_recognizer/tests/support/create_iam_lines_support_files.py 38
3 files changed, 78 insertions, 4 deletions
diff --git a/src/text_recognizer/tests/support/create_emnist_lines_support_files.py b/src/text_recognizer/tests/support/create_emnist_lines_support_files.py
new file mode 100644
index 0000000..b200ff5
--- /dev/null
+++ b/src/text_recognizer/tests/support/create_emnist_lines_support_files.py
@@ -0,0 +1,38 @@
+"""Module for creating EMNIST Lines test support files."""
+from pathlib import Path
+import shutil
+
+from text_recognizer.datasets import EmnistLinesDataset
+import text_recognizer.util as util
+
+
+# Directory (sibling of this module) that receives the generated sample images.
+SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "emnist_lines"
+
+
+def create_emnist_lines_support_files() -> None:
+    """Regenerate the EMNIST Lines sample images used as test support files.
+
+    Removes and recreates ``SUPPORT_DIRNAME``, loads ``EmnistLinesDataset`` and
+    writes the samples at indices 0, 1 and 3 to ``<decoded label>.png``.
+    """
+    shutil.rmtree(SUPPORT_DIRNAME, ignore_errors=True)
+    SUPPORT_DIRNAME.mkdir()
+
+    # TODO: maybe have to add args to dataset.
+    dataset = EmnistLinesDataset()
+    dataset.load_or_generate_data()
+
+    for index in [0, 1, 3]:
+        image, target = dataset[index]
+        print(image.sum(), image.dtype)
+
+        # Method-style argmax with a positional axis works for both NumPy
+        # arrays and torch tensors, so the never-imported ``np`` is not needed.
+        # NOTE(review): assumes target is one-hot along its last axis and that
+        # the first entry is a leading token to skip -- confirm.
+        token_indices = target[1:].argmax(-1).flatten()
+
+        # NOTE(review): the EMNIST support script calls the mapper instead
+        # (``dataset.mapper(int(label))``); confirm indexing is supported here.
+        label = (
+            "".join(dataset.mapper[int(token)] for token in token_indices)
+            .strip()
+            .strip(dataset.mapper.pad_token)
+        )
+
+        print(label)
+        util.write_image(image, str(SUPPORT_DIRNAME / f"{label}.png"))
+
+
+# Regenerate the support files when this module is executed as a script.
+if __name__ == "__main__":
+ create_emnist_lines_support_files()
diff --git a/src/text_recognizer/tests/support/create_emnist_support_files.py b/src/text_recognizer/tests/support/create_emnist_support_files.py
index c04860d..f9ff030 100644
--- a/src/text_recognizer/tests/support/create_emnist_support_files.py
+++ b/src/text_recognizer/tests/support/create_emnist_support_files.py
@@ -2,8 +2,7 @@
from pathlib import Path
import shutil
-from text_recognizer.datasets.emnist_dataset import EmnistDataset
-from text_recognizer.datasets.util import EmnistMapper
+from text_recognizer.datasets import EmnistDataset
from text_recognizer.util import write_image
SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "emnist"
@@ -16,14 +15,13 @@ def create_emnist_support_files() -> None:
dataset = EmnistDataset(train=False)
dataset.load_or_generate_data()
- mapping = EmnistMapper()
for index in [5, 7, 9]:
image, label = dataset[index]
if len(image.shape) == 3:
image = image.squeeze(0)
image = image.numpy()
- label = mapping(int(label))
+ label = dataset.mapper(int(label))
print(index, label)
write_image(image, str(SUPPORT_DIRNAME / f"{label}.png"))
diff --git a/src/text_recognizer/tests/support/create_iam_lines_support_files.py b/src/text_recognizer/tests/support/create_iam_lines_support_files.py
new file mode 100644
index 0000000..15d0e4e
--- /dev/null
+++ b/src/text_recognizer/tests/support/create_iam_lines_support_files.py
@@ -0,0 +1,38 @@
+"""Module for creating IAM Lines test support files."""
+from pathlib import Path
+import shutil
+
+from text_recognizer.datasets import IamLinesDataset
+import text_recognizer.util as util
+
+
+# Directory (sibling of this module) that receives the generated sample images.
+SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "iam_lines"
+
+
+def create_emnist_lines_support_files() -> None:
+    """Regenerate the IAM Lines sample images used as test support files.
+
+    Removes and recreates ``SUPPORT_DIRNAME``, loads ``IamLinesDataset`` and
+    writes the samples at indices 0, 1 and 3 to ``<decoded label>.png``.
+
+    NOTE(review): the name says "emnist" although this module handles IAM
+    Lines; it is kept unchanged because the ``__main__`` guard calls it.
+    """
+    shutil.rmtree(SUPPORT_DIRNAME, ignore_errors=True)
+    SUPPORT_DIRNAME.mkdir()
+
+    # TODO: maybe have to add args to dataset.
+    dataset = IamLinesDataset()
+    dataset.load_or_generate_data()
+
+    for index in [0, 1, 3]:
+        image, target = dataset[index]
+        print(image.sum(), image.dtype)
+
+        # Method-style argmax with a positional axis works for both NumPy
+        # arrays and torch tensors, so the never-imported ``np`` is not needed.
+        # NOTE(review): assumes target is one-hot along its last axis and that
+        # the first entry is a leading token to skip -- confirm.
+        token_indices = target[1:].argmax(-1).flatten()
+
+        # NOTE(review): the EMNIST support script calls the mapper instead
+        # (``dataset.mapper(int(label))``); confirm indexing is supported here.
+        label = (
+            "".join(dataset.mapper[int(token)] for token in token_indices)
+            .strip()
+            .strip(dataset.mapper.pad_token)
+        )
+
+        print(label)
+        util.write_image(image, str(SUPPORT_DIRNAME / f"{label}.png"))
+
+
+# Regenerate the support files when this module is executed as a script.
+if __name__ == "__main__":
+ create_emnist_lines_support_files()