diff options
Diffstat (limited to 'text_recognizer/data')
-rw-r--r-- | text_recognizer/data/emnist.py | 4 | ||||
-rw-r--r-- | text_recognizer/data/iam_lines.py | 14 | ||||
-rw-r--r-- | text_recognizer/data/iam_paragraphs.py | 20 | ||||
-rw-r--r-- | text_recognizer/data/iam_synthetic_paragraphs.py | 4 |
4 files changed, 19 insertions, 23 deletions
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py index bf3faec..824b947 100644 --- a/text_recognizer/data/emnist.py +++ b/text_recognizer/data/emnist.py @@ -10,7 +10,7 @@ import h5py from loguru import logger import numpy as np import toml -from torchvision import transforms +import torchvision.transforms as T from text_recognizer.data.base_data_module import ( BaseDataModule, @@ -53,7 +53,7 @@ class EMNIST(BaseDataModule): self.data_train = None self.data_val = None self.data_test = None - self.transform = transforms.Compose([transforms.ToTensor()]) + self.transform = T.Compose([T.ToTensor()]) self.dims = (1, *self.input_shape) self.output_dims = (1,) diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py index 78bc8e1..9c78a22 100644 --- a/text_recognizer/data/iam_lines.py +++ b/text_recognizer/data/iam_lines.py @@ -13,7 +13,7 @@ from loguru import logger from PIL import Image, ImageFile, ImageOps import numpy as np from torch import Tensor -from torchvision import transforms +import torchvision.transforms as T from torchvision.transforms.functional import InterpolationMode from text_recognizer.data.base_dataset import ( @@ -208,7 +208,7 @@ def load_line_crops_and_labels(split: str, data_dirname: Path) -> Tuple[List, Li return crops, labels -def get_transform(image_width: int, augment: bool = False) -> transforms.Compose: +def get_transform(image_width: int, augment: bool = False) -> T.Compose: """Augment with brigthness, sligth rotation, slant, translation, scale, and Gaussian noise.""" def embed_crop( @@ -237,20 +237,20 @@ def get_transform(image_width: int, augment: bool = False) -> transforms.Compose return image - transfroms_list = [transforms.Lambda(embed_crop)] + transfroms_list = [T.Lambda(embed_crop)] if augment: transfroms_list += [ - transforms.ColorJitter(brightness=(0.8, 1.6)), - transforms.RandomAffine( + T.ColorJitter(brightness=(0.8, 1.6)), + T.RandomAffine( degrees=1, shear=(-30, 20), interpolation=InterpolationMode.BILINEAR, fill=0, ), ] - transfroms_list.append(transforms.ToTensor()) - return transforms.Compose(transfroms_list) + transfroms_list.append(T.ToTensor()) + return T.Compose(transfroms_list) def generate_iam_lines() -> None: diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index 24409bc..6022804 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Sequence, Tuple from loguru import logger import numpy as np from PIL import Image, ImageOps -import torchvision.transforms as transforms +import torchvision.transforms as T from torchvision.transforms.functional import InterpolationMode from tqdm import tqdm @@ -270,31 +270,31 @@ def _load_processed_crops_and_labels( return ordered_crops, ordered_labels -def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Compose: +def get_transform(image_shape: Tuple[int, int], augment: bool) -> T.Compose: """Get transformations for images.""" if augment: transforms_list = [ - transforms.RandomCrop( + T.RandomCrop( size=image_shape, padding=None, pad_if_needed=True, fill=0, padding_mode="constant", ), - transforms.ColorJitter(brightness=(0.8, 1.6)), - transforms.RandomAffine( + T.ColorJitter(brightness=(0.8, 1.6)), + T.RandomAffine( degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR, ), ] else: - transforms_list = [transforms.CenterCrop(image_shape)] - transforms_list.append(transforms.ToTensor()) - return transforms.Compose(transforms_list) + transforms_list = [T.CenterCrop(image_shape)] + transforms_list.append(T.ToTensor()) + return T.Compose(transforms_list) -def get_target_transform(word_pieces: bool) -> Optional[transforms.Compose]: +def get_target_transform(word_pieces: bool) -> Optional[T.Compose]: """Transform emnist characters to word pieces.""" - return transforms.Compose([WordPiece()]) if word_pieces else None + return T.Compose([WordPiece()]) if word_pieces else None def _labels_filename(split: str) -> Path: diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py index ad6fa25..00fa2b6 100644 --- a/text_recognizer/data/iam_synthetic_paragraphs.py +++ b/text_recognizer/data/iam_synthetic_paragraphs.py @@ -1,8 +1,5 @@ """IAM Synthetic Paragraphs Dataset class.""" -import itertools -from pathlib import Path import random -import time from typing import Any, List, Sequence, Tuple from loguru import logger @@ -12,7 +9,6 @@ from PIL import Image from text_recognizer.data.base_dataset import ( BaseDataset, convert_strings_to_labels, - split_dataset, ) from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.iam_paragraphs import ( |