summaryrefslogtreecommitdiff
path: root/text_recognizer/data/iam_paragraphs.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-22 08:15:58 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-22 08:15:58 +0200
commit1ca8b0b9e0613c1e02f6a5d8b49e20c4d6916412 (patch)
tree5e610ac459c9b254f8826e92372346f01f8e2412 /text_recognizer/data/iam_paragraphs.py
parentffa4be4bf4e3758e01d52a9c1f354a05a90b93de (diff)
Fixed training script, able to train vqvae
Diffstat (limited to 'text_recognizer/data/iam_paragraphs.py')
-rw-r--r--text_recognizer/data/iam_paragraphs.py23
1 files changed, 16 insertions, 7 deletions
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index f588587..62c44f9 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -5,8 +5,7 @@ from typing import Dict, List, Optional, Sequence, Tuple
from loguru import logger
import numpy as np
-from PIL import Image, ImageFile, ImageOps
-import torch
+from PIL import Image, ImageOps
import torchvision.transforms as transforms
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm
@@ -19,6 +18,7 @@ from text_recognizer.data.base_dataset import (
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.emnist import emnist_mapping
from text_recognizer.data.iam import IAM
+from text_recognizer.data.transforms import WordPiece
PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_paragraphs"
@@ -37,15 +37,15 @@ class IAMParagraphs(BaseDataModule):
def __init__(
self,
- batch_size: int = 128,
+ batch_size: int = 16,
num_workers: int = 0,
train_fraction: float = 0.8,
augment: bool = True,
+ word_pieces: bool = False,
) -> None:
super().__init__(batch_size, num_workers)
- # TODO: pass in transform and target transform
- # TODO: pass in mapping
self.augment = augment
+ self.word_pieces = word_pieces
self.mapping, self.inverse_mapping, _ = emnist_mapping(
extra_symbols=[NEW_LINE_TOKEN]
)
@@ -101,6 +101,7 @@ class IAMParagraphs(BaseDataModule):
data,
targets,
transform=get_transform(image_shape=self.dims[1:], augment=augment),
+ target_transform=get_target_transform(self.word_pieces)
)
logger.info(f"Loading IAM paragraph regions and lines for {stage}...")
@@ -161,7 +162,10 @@ def get_dataset_properties() -> Dict:
"min": min(_get_property_values("num_lines")),
"max": max(_get_property_values("num_lines")),
},
- "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0),},
+ "crop_shape": {
+ "min": crop_shapes.min(axis=0),
+ "max": crop_shapes.max(axis=0),
+ },
"aspect_ratio": {
"min": aspect_ratio.min(axis=0),
"max": aspect_ratio.max(axis=0),
@@ -282,7 +286,9 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Com
),
transforms.ColorJitter(brightness=(0.8, 1.6)),
transforms.RandomAffine(
- degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR,
+ degrees=1,
+ shear=(-10, 10),
+ interpolation=InterpolationMode.BILINEAR,
),
]
else:
@@ -290,6 +296,9 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Com
transforms_list.append(transforms.ToTensor())
return transforms.Compose(transforms_list)
+def get_target_transform(word_pieces: bool) -> Optional[transforms.Compose]:
+ """Transform emnist characters to word pieces."""
+ return transforms.Compose([WordPiece()]) if word_pieces else None
def _labels_filename(split: str) -> Path:
"""Return filename of processed labels."""