diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-22 08:15:58 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-22 08:15:58 +0200 |
commit | 1ca8b0b9e0613c1e02f6a5d8b49e20c4d6916412 (patch) | |
tree | 5e610ac459c9b254f8826e92372346f01f8e2412 /text_recognizer/data/iam_paragraphs.py | |
parent | ffa4be4bf4e3758e01d52a9c1f354a05a90b93de (diff) |
Fixed training script, able to train vqvae
Diffstat (limited to 'text_recognizer/data/iam_paragraphs.py')
-rw-r--r-- | text_recognizer/data/iam_paragraphs.py | 23 |
1 files changed, 16 insertions, 7 deletions
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index f588587..62c44f9 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -5,8 +5,7 @@ from typing import Dict, List, Optional, Sequence, Tuple from loguru import logger import numpy as np -from PIL import Image, ImageFile, ImageOps -import torch +from PIL import Image, ImageOps import torchvision.transforms as transforms from torchvision.transforms.functional import InterpolationMode from tqdm import tqdm @@ -19,6 +18,7 @@ from text_recognizer.data.base_dataset import ( from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.emnist import emnist_mapping from text_recognizer.data.iam import IAM +from text_recognizer.data.transforms import WordPiece PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_paragraphs" @@ -37,15 +37,15 @@ class IAMParagraphs(BaseDataModule): def __init__( self, - batch_size: int = 128, + batch_size: int = 16, num_workers: int = 0, train_fraction: float = 0.8, augment: bool = True, + word_pieces: bool = False, ) -> None: super().__init__(batch_size, num_workers) - # TODO: pass in transform and target transform - # TODO: pass in mapping self.augment = augment + self.word_pieces = word_pieces self.mapping, self.inverse_mapping, _ = emnist_mapping( extra_symbols=[NEW_LINE_TOKEN] ) @@ -101,6 +101,7 @@ class IAMParagraphs(BaseDataModule): data, targets, transform=get_transform(image_shape=self.dims[1:], augment=augment), + target_transform=get_target_transform(self.word_pieces) ) logger.info(f"Loading IAM paragraph regions and lines for {stage}...") @@ -161,7 +162,10 @@ def get_dataset_properties() -> Dict: "min": min(_get_property_values("num_lines")), "max": max(_get_property_values("num_lines")), }, - "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0),}, + "crop_shape": { + "min": crop_shapes.min(axis=0), + "max": crop_shapes.max(axis=0), + }, "aspect_ratio": { "min": aspect_ratio.min(axis=0), "max": aspect_ratio.max(axis=0), @@ -282,7 +286,9 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Com ), transforms.ColorJitter(brightness=(0.8, 1.6)), transforms.RandomAffine( - degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR, + degrees=1, + shear=(-10, 10), + interpolation=InterpolationMode.BILINEAR, ), ] else: @@ -290,6 +296,9 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Com transforms_list.append(transforms.ToTensor()) return transforms.Compose(transforms_list) +def get_target_transform(word_pieces: bool) -> Optional[transforms.Compose]: + """Transform emnist characters to word pieces.""" + return transforms.Compose([WordPiece()]) if word_pieces else None def _labels_filename(split: str) -> Path: """Return filename of processed labels.""" |