summaryrefslogtreecommitdiff
path: root/text_recognizer/data
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-06-26 00:35:02 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-06-26 00:35:02 +0200
commit22e36513dd43d2e2ca82ca28a1ea757c5663676a (patch)
tree54285c3c30a02b00af989078bf61c122b9eccabd /text_recognizer/data
parent9c3a8753d95ecb70a84e1eb40933590a510abfc4 (diff)
Updates
Diffstat (limited to 'text_recognizer/data')
-rw-r--r--text_recognizer/data/emnist.py4
-rw-r--r--text_recognizer/data/iam_lines.py14
-rw-r--r--text_recognizer/data/iam_paragraphs.py20
-rw-r--r--text_recognizer/data/iam_synthetic_paragraphs.py4
4 files changed, 19 insertions, 23 deletions
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index bf3faec..824b947 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -10,7 +10,7 @@ import h5py
from loguru import logger
import numpy as np
import toml
-from torchvision import transforms
+import torchvision.transforms as T
from text_recognizer.data.base_data_module import (
BaseDataModule,
@@ -53,7 +53,7 @@ class EMNIST(BaseDataModule):
self.data_train = None
self.data_val = None
self.data_test = None
- self.transform = transforms.Compose([transforms.ToTensor()])
+ self.transform = T.Compose([T.ToTensor()])
self.dims = (1, *self.input_shape)
self.output_dims = (1,)
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index 78bc8e1..9c78a22 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -13,7 +13,7 @@ from loguru import logger
from PIL import Image, ImageFile, ImageOps
import numpy as np
from torch import Tensor
-from torchvision import transforms
+import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from text_recognizer.data.base_dataset import (
@@ -208,7 +208,7 @@ def load_line_crops_and_labels(split: str, data_dirname: Path) -> Tuple[List, Li
return crops, labels
-def get_transform(image_width: int, augment: bool = False) -> transforms.Compose:
+def get_transform(image_width: int, augment: bool = False) -> T.Compose:
"""Augment with brigthness, sligth rotation, slant, translation, scale, and Gaussian noise."""
def embed_crop(
@@ -237,20 +237,20 @@ def get_transform(image_width: int, augment: bool = False) -> transforms.Compose
return image
- transfroms_list = [transforms.Lambda(embed_crop)]
+ transfroms_list = [T.Lambda(embed_crop)]
if augment:
transfroms_list += [
- transforms.ColorJitter(brightness=(0.8, 1.6)),
- transforms.RandomAffine(
+ T.ColorJitter(brightness=(0.8, 1.6)),
+ T.RandomAffine(
degrees=1,
shear=(-30, 20),
interpolation=InterpolationMode.BILINEAR,
fill=0,
),
]
- transfroms_list.append(transforms.ToTensor())
- return transforms.Compose(transfroms_list)
+ transfroms_list.append(T.ToTensor())
+ return T.Compose(transfroms_list)
def generate_iam_lines() -> None:
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 24409bc..6022804 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Sequence, Tuple
from loguru import logger
import numpy as np
from PIL import Image, ImageOps
-import torchvision.transforms as transforms
+import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm
@@ -270,31 +270,31 @@ def _load_processed_crops_and_labels(
return ordered_crops, ordered_labels
-def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Compose:
+def get_transform(image_shape: Tuple[int, int], augment: bool) -> T.Compose:
"""Get transformations for images."""
if augment:
transforms_list = [
- transforms.RandomCrop(
+ T.RandomCrop(
size=image_shape,
padding=None,
pad_if_needed=True,
fill=0,
padding_mode="constant",
),
- transforms.ColorJitter(brightness=(0.8, 1.6)),
- transforms.RandomAffine(
+ T.ColorJitter(brightness=(0.8, 1.6)),
+ T.RandomAffine(
degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR,
),
]
else:
- transforms_list = [transforms.CenterCrop(image_shape)]
- transforms_list.append(transforms.ToTensor())
- return transforms.Compose(transforms_list)
+ transforms_list = [T.CenterCrop(image_shape)]
+ transforms_list.append(T.ToTensor())
+ return T.Compose(transforms_list)
-def get_target_transform(word_pieces: bool) -> Optional[transforms.Compose]:
+def get_target_transform(word_pieces: bool) -> Optional[T.Compose]:
"""Transform emnist characters to word pieces."""
- return transforms.Compose([WordPiece()]) if word_pieces else None
+ return T.Compose([WordPiece()]) if word_pieces else None
def _labels_filename(split: str) -> Path:
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index ad6fa25..00fa2b6 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -1,8 +1,5 @@
"""IAM Synthetic Paragraphs Dataset class."""
-import itertools
-from pathlib import Path
import random
-import time
from typing import Any, List, Sequence, Tuple
from loguru import logger
@@ -12,7 +9,6 @@ from PIL import Image
from text_recognizer.data.base_dataset import (
BaseDataset,
convert_strings_to_labels,
- split_dataset,
)
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.iam_paragraphs import (