summaryrefslogtreecommitdiff
path: root/text_recognizer/data/emnist_lines.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-10 18:03:11 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-10 18:03:11 +0200
commit30e3ae483c846418b04ed48f014a4af2cf9a0771 (patch)
tree08a309e14416e68ad351be8a3d48bf50efd80d6b /text_recognizer/data/emnist_lines.py
parentad3f404d36a9add32992698dd083d368f3b96812 (diff)
Update transforms in datamodule/set
Diffstat (limited to 'text_recognizer/data/emnist_lines.py')
-rw-r--r--text_recognizer/data/emnist_lines.py55
1 files changed, 16 insertions, 39 deletions
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 3ff8a54..1a64931 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -1,7 +1,7 @@
"""Dataset of generated text from EMNIST characters."""
from collections import defaultdict
from pathlib import Path
-from typing import Callable, List, Tuple
+from typing import DefaultDict, List, Tuple
import attr
import h5py
@@ -9,8 +9,7 @@ from loguru import logger as log
import numpy as np
import torch
from torch import Tensor
-from torchvision import transforms
-from torchvision.transforms.functional import InterpolationMode
+import torchvision.transforms as T
from text_recognizer.data.base_data_module import (
BaseDataModule,
@@ -18,12 +17,13 @@ from text_recognizer.data.base_data_module import (
)
from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels
from text_recognizer.data.emnist import EMNIST
-from text_recognizer.data.sentence_generator import SentenceGenerator
+from text_recognizer.data.utils.sentence_generator import SentenceGenerator
+from text_recognizer.data.transforms.load_transform import load_transform_from_file
DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist_lines"
ESSENTIALS_FILENAME = (
- Path(__file__).parents[0].resolve() / "emnist_lines_essentials.json"
+ Path(__file__).parents[0].resolve() / "mappings" / "emnist_lines_essentials.json"
)
SEED = 4711
@@ -37,7 +37,6 @@ MAX_OUTPUT_LENGTH = 89 # Same as IAMLines
class EMNISTLines(BaseDataModule):
"""EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST."""
- augment: bool = attr.ib(default=True)
max_length: int = attr.ib(default=128)
min_overlap: float = attr.ib(default=0.0)
max_overlap: float = attr.ib(default=0.33)
@@ -98,21 +97,15 @@ class EMNISTLines(BaseDataModule):
x_val = f["x_val"][:]
y_val = torch.LongTensor(f["y_val"][:])
- self.data_train = BaseDataset(
- x_train, y_train, transform=_get_transform(augment=self.augment)
- )
- self.data_val = BaseDataset(
- x_val, y_val, transform=_get_transform(augment=self.augment)
- )
+ self.data_train = BaseDataset(x_train, y_train, transform=self.transform)
+ self.data_val = BaseDataset(x_val, y_val, transform=self.transform)
if stage == "test" or stage is None:
with h5py.File(self.data_filename, "r") as f:
x_test = f["x_test"][:]
y_test = torch.LongTensor(f["y_test"][:])
- self.data_test = BaseDataset(
- x_test, y_test, transform=_get_transform(augment=False)
- )
+ self.data_test = BaseDataset(x_test, y_test, transform=self.test_transform)
def __repr__(self) -> str:
"""Return str about dataset."""
@@ -129,6 +122,7 @@ class EMNISTLines(BaseDataModule):
return basic
x, y = next(iter(self.train_dataloader()))
+ x = x[0] if isinstance(x, list) else x
data = (
"Train/val/test sizes: "
f"{len(self.data_train)}, "
@@ -184,7 +178,7 @@ class EMNISTLines(BaseDataModule):
def _get_samples_by_char(
samples: np.ndarray, labels: np.ndarray, mapping: List
-) -> defaultdict:
+) -> DefaultDict:
samples_by_char = defaultdict(list)
for sample, label in zip(samples, labels):
samples_by_char[mapping[label]].append(sample)
@@ -192,7 +186,7 @@ def _get_samples_by_char(
def _select_letter_samples_for_string(
- string: str, samples_by_char: defaultdict
+ string: str, samples_by_char: DefaultDict
) -> List[Tensor]:
null_image = torch.zeros((28, 28), dtype=torch.uint8)
sample_image_by_char = {}
@@ -207,7 +201,7 @@ def _select_letter_samples_for_string(
def _construct_image_from_string(
string: str,
- samples_by_char: defaultdict,
+ samples_by_char: DefaultDict,
min_overlap: float,
max_overlap: float,
width: int,
@@ -226,7 +220,7 @@ def _construct_image_from_string(
def _create_dataset_of_images(
num_samples: int,
- samples_by_char: defaultdict,
+ samples_by_char: DefaultDict,
sentence_generator: SentenceGenerator,
min_overlap: float,
max_overlap: float,
@@ -246,25 +240,8 @@ def _create_dataset_of_images(
return images, labels
-def _get_transform(augment: bool = False) -> Callable:
- if not augment:
- return transforms.Compose([transforms.ToTensor()])
- return transforms.Compose(
- [
- transforms.ToTensor(),
- transforms.ColorJitter(brightness=(0.5, 1.0)),
- transforms.RandomAffine(
- degrees=3,
- translate=(0.0, 0.05),
- scale=(0.4, 1.1),
- shear=(-40, 50),
- interpolation=InterpolationMode.BILINEAR,
- fill=0,
- ),
- ]
- )
-
-
def generate_emnist_lines() -> None:
"""Generates a synthetic handwritten dataset and displays info."""
- load_and_print_info(EMNISTLines)
+ transform = load_transform_from_file("transform/emnist_lines.yaml")
+ test_transform = load_transform_from_file("test_transform/default.yaml")
+ load_and_print_info(EMNISTLines(transform=transform, test_transform=test_transform))