summary refs log tree commit diff
path: root/text_recognizer/data/emnist_lines.py
diff options
context:
space:
mode:
author Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-03-28 22:02:24 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-03-28 22:02:24 +0200
commit 46a1472d33d3a4180798492e819f2ec02bc3b1a3 (patch)
tree 22322ed0d8f9f803966ea745ec5bb8c759f8db64 /text_recognizer/data/emnist_lines.py
parent 8248f173132dfb7e47ec62b08e9235990c8626e3 (diff)
Add refactor of iam lines
Diffstat (limited to 'text_recognizer/data/emnist_lines.py')
-rw-r--r-- text_recognizer/data/emnist_lines.py 32
1 files changed, 15 insertions, 17 deletions
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 6c14add..72665d0 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -1,12 +1,11 @@
"""Dataset of generated text from EMNIST characters."""
from collections import defaultdict
from pathlib import Path
-from typing import Callable, Dict, Tuple, Sequence
+from typing import Callable, Dict, Tuple
import h5py
from loguru import logger
import numpy as np
-from PIL import Image
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
@@ -58,6 +57,7 @@ class EMNISTLines(BaseDataModule):
self.num_test = num_test
self.emnist = EMNIST()
+ # TODO: fix mapping
self.mapping = self.emnist.mapping
max_width = (
int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap))
@@ -66,32 +66,28 @@ class EMNISTLines(BaseDataModule):
if max_width >= IMAGE_WIDTH:
raise ValueError(
- f"max_width {max_width} greater than IMAGE_WIDTH {IMAGE_WIDTH}"
- )
+ f"max_width {max_width} greater than IMAGE_WIDTH {IMAGE_WIDTH}"
+ )
- self.dims = (
- self.emnist.dims[0],
- IMAGE_HEIGHT,
- IMAGE_WIDTH
- )
+ self.dims = (self.emnist.dims[0], IMAGE_HEIGHT, IMAGE_WIDTH)
if self.max_length >= MAX_OUTPUT_LENGTH:
raise ValueError("max_length greater than MAX_OUTPUT_LENGTH")
self.output_dims = (MAX_OUTPUT_LENGTH, 1)
- self.data_train = None
- self.data_val = None
- self.data_test = None
+ self.data_train: BaseDataset = None
+ self.data_val: BaseDataset = None
+ self.data_test: BaseDataset = None
@property
def data_filename(self) -> Path:
"""Return name of dataset."""
- return (
- DATA_DIRNAME / (f"ml_{self.max_length}_"
+ return DATA_DIRNAME / (
+ f"ml_{self.max_length}_"
f"o{self.min_overlap:f}_{self.max_overlap:f}_"
f"ntr{self.num_train}_"
f"ntv{self.num_val}_"
- f"nte{self.num_test}.h5")
+ f"nte{self.num_test}.h5"
)
def prepare_data(self) -> None:
@@ -144,7 +140,10 @@ class EMNISTLines(BaseDataModule):
x, y = next(iter(self.train_dataloader()))
data = (
- f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n"
+ "Train/val/test sizes: "
+ f"{len(self.data_train)}, "
+ f"{len(self.data_val)}, "
+ f"{len(self.data_test)}\n"
f"Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
f"Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n"
)
@@ -223,7 +222,6 @@ def _construct_image_from_string(
) -> torch.Tensor:
overlap = np.random.uniform(min_overlap, max_overlap)
sampled_images = _select_letter_samples_for_string(string, samples_by_char)
- N = len(sampled_images)
H, W = sampled_images[0].shape
next_overlap_width = W - int(overlap * W)
concatenated_image = torch.zeros((H, width), dtype=torch.uint8)