diff options
Diffstat (limited to 'text_recognizer/data/iam_paragraphs.py')
-rw-r--r-- | text_recognizer/data/iam_paragraphs.py | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index 24409bc..6022804 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Sequence, Tuple from loguru import logger import numpy as np from PIL import Image, ImageOps -import torchvision.transforms as transforms +import torchvision.transforms as T from torchvision.transforms.functional import InterpolationMode from tqdm import tqdm @@ -270,31 +270,31 @@ def _load_processed_crops_and_labels( return ordered_crops, ordered_labels -def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Compose: +def get_transform(image_shape: Tuple[int, int], augment: bool) -> T.Compose: """Get transformations for images.""" if augment: transforms_list = [ - transforms.RandomCrop( + T.RandomCrop( size=image_shape, padding=None, pad_if_needed=True, fill=0, padding_mode="constant", ), - transforms.ColorJitter(brightness=(0.8, 1.6)), - transforms.RandomAffine( + T.ColorJitter(brightness=(0.8, 1.6)), + T.RandomAffine( degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR, ), ] else: - transforms_list = [transforms.CenterCrop(image_shape)] - transforms_list.append(transforms.ToTensor()) - return transforms.Compose(transforms_list) + transforms_list = [T.CenterCrop(image_shape)] + transforms_list.append(T.ToTensor()) + return T.Compose(transforms_list) -def get_target_transform(word_pieces: bool) -> Optional[transforms.Compose]: +def get_target_transform(word_pieces: bool) -> Optional[T.Compose]: """Transform emnist characters to word pieces.""" - return transforms.Compose([WordPiece()]) if word_pieces else None + return T.Compose([WordPiece()]) if word_pieces else None def _labels_filename(split: str) -> Path: |