summaryrefslogtreecommitdiff
path: root/text_recognizer/data/emnist_lines.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-03 00:31:00 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-03 00:31:00 +0200
commit3a21c29e2eff4378c63717f8920ca3ccbfef013c (patch)
treeba46504d7baa8d4fb5bfd473acf99a7a184b330c /text_recognizer/data/emnist_lines.py
parent75eb34020620584247313926527019471411f6af (diff)
Lint files
Diffstat (limited to 'text_recognizer/data/emnist_lines.py')
-rw-r--r--text_recognizer/data/emnist_lines.py11
1 files changed, 7 insertions, 4 deletions
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index d4b2b40..3ff8a54 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -8,6 +8,7 @@ import h5py
from loguru import logger as log
import numpy as np
import torch
+from torch import Tensor
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
@@ -190,7 +191,9 @@ def _get_samples_by_char(
return samples_by_char
-def _select_letter_samples_for_string(string: str, samples_by_char: defaultdict):
+def _select_letter_samples_for_string(
+ string: str, samples_by_char: defaultdict
+) -> List[Tensor]:
null_image = torch.zeros((28, 28), dtype=torch.uint8)
sample_image_by_char = {}
for char in string:
@@ -208,7 +211,7 @@ def _construct_image_from_string(
min_overlap: float,
max_overlap: float,
width: int,
-) -> torch.Tensor:
+) -> Tensor:
overlap = np.random.uniform(min_overlap, max_overlap)
sampled_images = _select_letter_samples_for_string(string, samples_by_char)
H, W = sampled_images[0].shape
@@ -218,7 +221,7 @@ def _construct_image_from_string(
for image in sampled_images:
concatenated_image[:, x : (x + W)] += image
x += next_overlap_width
- return torch.minimum(torch.Tensor([255]), concatenated_image)
+ return torch.minimum(Tensor([255]), concatenated_image)
def _create_dataset_of_images(
@@ -228,7 +231,7 @@ def _create_dataset_of_images(
min_overlap: float,
max_overlap: float,
dims: Tuple,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> Tuple[Tensor, Tensor]:
images = torch.zeros((num_samples, IMAGE_HEIGHT, dims[2]))
labels = []
for n in range(num_samples):