summaryrefslogtreecommitdiff
path: root/text_recognizer/data/iam_paragraphs.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/iam_paragraphs.py')
-rw-r--r--text_recognizer/data/iam_paragraphs.py20
1 files changed, 10 insertions, 10 deletions
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 24409bc..6022804 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Sequence, Tuple
from loguru import logger
import numpy as np
from PIL import Image, ImageOps
-import torchvision.transforms as transforms
+import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm
@@ -270,31 +270,31 @@ def _load_processed_crops_and_labels(
return ordered_crops, ordered_labels
-def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Compose:
+def get_transform(image_shape: Tuple[int, int], augment: bool) -> T.Compose:
"""Get transformations for images."""
if augment:
transforms_list = [
- transforms.RandomCrop(
+ T.RandomCrop(
size=image_shape,
padding=None,
pad_if_needed=True,
fill=0,
padding_mode="constant",
),
- transforms.ColorJitter(brightness=(0.8, 1.6)),
- transforms.RandomAffine(
+ T.ColorJitter(brightness=(0.8, 1.6)),
+ T.RandomAffine(
degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR,
),
]
else:
- transforms_list = [transforms.CenterCrop(image_shape)]
- transforms_list.append(transforms.ToTensor())
- return transforms.Compose(transforms_list)
+ transforms_list = [T.CenterCrop(image_shape)]
+ transforms_list.append(T.ToTensor())
+ return T.Compose(transforms_list)
-def get_target_transform(word_pieces: bool) -> Optional[transforms.Compose]:
+def get_target_transform(word_pieces: bool) -> Optional[T.Compose]:
"""Transform emnist characters to word pieces."""
- return transforms.Compose([WordPiece()]) if word_pieces else None
+ return T.Compose([WordPiece()]) if word_pieces else None
def _labels_filename(split: str) -> Path: