diff options
Diffstat (limited to 'src/text_recognizer/datasets/emnist_lines_dataset.py')
-rw-r--r-- | src/text_recognizer/datasets/emnist_lines_dataset.py | 11 |
1 files changed, 11 insertions, 0 deletions
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py index 6871492..eddf341 100644 --- a/src/text_recognizer/datasets/emnist_lines_dataset.py +++ b/src/text_recognizer/datasets/emnist_lines_dataset.py @@ -10,6 +10,7 @@ from loguru import logger import numpy as np import torch from torch import Tensor +import torch.nn.functional as F from torchvision.transforms import ToTensor from text_recognizer.datasets.dataset import Dataset @@ -23,6 +24,8 @@ from text_recognizer.datasets.util import ( DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines" +MAX_WIDTH = 952 + class EmnistLinesDataset(Dataset): """Synthetic dataset of lines from the Brown corpus with Emnist characters.""" @@ -254,6 +257,14 @@ def construct_image_from_string( for image in sampled_images: concatenated_image[:, x : (x + width)] += image x += next_overlap_width + + if concatenated_image.shape[-1] > MAX_WIDTH: + concatenated_image = Tensor(concatenated_image).unsqueeze(0) + concatenated_image = F.interpolate( + concatenated_image, size=MAX_WIDTH, mode="nearest" + ) + concatenated_image = concatenated_image.squeeze(0).numpy() + return np.minimum(255, concatenated_image) |