diff options
Diffstat (limited to 'text_recognizer/data/emnist_lines.py')
-rw-r--r-- | text_recognizer/data/emnist_lines.py | 11 |
1 files changed, 7 insertions, 4 deletions
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py index d4b2b40..3ff8a54 100644 --- a/text_recognizer/data/emnist_lines.py +++ b/text_recognizer/data/emnist_lines.py @@ -8,6 +8,7 @@ import h5py from loguru import logger as log import numpy as np import torch +from torch import Tensor from torchvision import transforms from torchvision.transforms.functional import InterpolationMode @@ -190,7 +191,9 @@ def _get_samples_by_char( return samples_by_char -def _select_letter_samples_for_string(string: str, samples_by_char: defaultdict): +def _select_letter_samples_for_string( + string: str, samples_by_char: defaultdict +) -> List[Tensor]: null_image = torch.zeros((28, 28), dtype=torch.uint8) sample_image_by_char = {} for char in string: @@ -208,7 +211,7 @@ def _construct_image_from_string( min_overlap: float, max_overlap: float, width: int, -) -> torch.Tensor: +) -> Tensor: overlap = np.random.uniform(min_overlap, max_overlap) sampled_images = _select_letter_samples_for_string(string, samples_by_char) H, W = sampled_images[0].shape @@ -218,7 +221,7 @@ def _construct_image_from_string( for image in sampled_images: concatenated_image[:, x : (x + W)] += image x += next_overlap_width - return torch.minimum(torch.Tensor([255]), concatenated_image) + return torch.minimum(Tensor([255]), concatenated_image) def _create_dataset_of_images( @@ -228,7 +231,7 @@ def _create_dataset_of_images( min_overlap: float, max_overlap: float, dims: Tuple, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[Tensor, Tensor]: images = torch.zeros((num_samples, IMAGE_HEIGHT, dims[2])) labels = [] for n in range(num_samples): |