summary | refs | log | tree | commit | diff
path: root/src/text_recognizer/datasets/emnist_lines_dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/datasets/emnist_lines_dataset.py')
-rw-r--r-- src/text_recognizer/datasets/emnist_lines_dataset.py 11
1 files changed, 11 insertions, 0 deletions
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index 6871492..eddf341 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -10,6 +10,7 @@ from loguru import logger
import numpy as np
import torch
from torch import Tensor
+import torch.nn.functional as F
from torchvision.transforms import ToTensor
from text_recognizer.datasets.dataset import Dataset
@@ -23,6 +24,8 @@ from text_recognizer.datasets.util import (
DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines"
+MAX_WIDTH = 952
+
class EmnistLinesDataset(Dataset):
"""Synthetic dataset of lines from the Brown corpus with Emnist characters."""
@@ -254,6 +257,14 @@ def construct_image_from_string(
for image in sampled_images:
concatenated_image[:, x : (x + width)] += image
x += next_overlap_width
+
+ if concatenated_image.shape[-1] > MAX_WIDTH:
+ concatenated_image = Tensor(concatenated_image).unsqueeze(0)
+ concatenated_image = F.interpolate(
+ concatenated_image, size=MAX_WIDTH, mode="nearest"
+ )
+ concatenated_image = concatenated_image.squeeze(0).numpy()
+
return np.minimum(255, concatenated_image)