summaryrefslogtreecommitdiff
path: root/text_recognizer/data/iam_lines.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/iam_lines.py')
-rw-r--r--text_recognizer/data/iam_lines.py14
1 files changed, 7 insertions, 7 deletions
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index 78bc8e1..9c78a22 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -13,7 +13,7 @@ from loguru import logger
from PIL import Image, ImageFile, ImageOps
import numpy as np
from torch import Tensor
-from torchvision import transforms
+import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from text_recognizer.data.base_dataset import (
@@ -208,7 +208,7 @@ def load_line_crops_and_labels(split: str, data_dirname: Path) -> Tuple[List, Li
return crops, labels
-def get_transform(image_width: int, augment: bool = False) -> transforms.Compose:
+def get_transform(image_width: int, augment: bool = False) -> T.Compose:
"""Augment with brigthness, sligth rotation, slant, translation, scale, and Gaussian noise."""
def embed_crop(
@@ -237,20 +237,20 @@ def get_transform(image_width: int, augment: bool = False) -> transforms.Compose
return image
- transfroms_list = [transforms.Lambda(embed_crop)]
+ transfroms_list = [T.Lambda(embed_crop)]
if augment:
transfroms_list += [
- transforms.ColorJitter(brightness=(0.8, 1.6)),
- transforms.RandomAffine(
+ T.ColorJitter(brightness=(0.8, 1.6)),
+ T.RandomAffine(
degrees=1,
shear=(-30, 20),
interpolation=InterpolationMode.BILINEAR,
fill=0,
),
]
- transfroms_list.append(transforms.ToTensor())
- return transforms.Compose(transfroms_list)
+ transfroms_list.append(T.ToTensor())
+ return T.Compose(transfroms_list)
def generate_iam_lines() -> None: