author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-22 08:15:58 +0200
---|---|---
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-22 08:15:58 +0200
commit | 1ca8b0b9e0613c1e02f6a5d8b49e20c4d6916412 (patch) |
tree | 5e610ac459c9b254f8826e92372346f01f8e2412 | /text_recognizer/data
parent | ffa4be4bf4e3758e01d52a9c1f354a05a90b93de (diff) |
Fixed training script, able to train vqvae
Diffstat (limited to 'text_recognizer/data')
-rw-r--r-- | text_recognizer/data/__init__.py | 3
-rw-r--r-- | text_recognizer/data/emnist_lines.py | 2
-rw-r--r-- | text_recognizer/data/iam_extended_paragraphs.py | 15
-rw-r--r-- | text_recognizer/data/iam_paragraphs.py | 23
-rw-r--r-- | text_recognizer/data/iam_preprocessor.py | 1
-rw-r--r-- | text_recognizer/data/iam_synthetic_paragraphs.py | 7
-rw-r--r-- | text_recognizer/data/mappings.py | 16
-rw-r--r-- | text_recognizer/data/transforms.py | 14
8 files changed, 54 insertions, 27 deletions
diff --git a/text_recognizer/data/__init__.py b/text_recognizer/data/__init__.py
index 9a42fa9..3599a8b 100644
--- a/text_recognizer/data/__init__.py
+++ b/text_recognizer/data/__init__.py
@@ -2,3 +2,6 @@
 from .base_dataset import BaseDataset, convert_strings_to_labels, split_dataset
 from .base_data_module import BaseDataModule, load_and_print_info
 from .download_utils import download_dataset
+from .iam_paragraphs import IAMParagraphs
+from .iam_synthetic_paragraphs import IAMSyntheticParagraphs
+from .iam_extended_paragraphs import IAMExtendedParagraphs
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 72665d0..9650198 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -57,8 +57,8 @@ class EMNISTLines(BaseDataModule):
         self.num_test = num_test

         self.emnist = EMNIST()
-        # TODO: fix mapping
         self.mapping = self.emnist.mapping
+
         max_width = (
             int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap))
             + IMAGE_X_PADDING
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index d2529b4..2380660 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -10,18 +10,27 @@ from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs


 class IAMExtendedParagraphs(BaseDataModule):
     def __init__(
         self,
-        batch_size: int = 128,
+        batch_size: int = 16,
         num_workers: int = 0,
         train_fraction: float = 0.8,
         augment: bool = True,
+        word_pieces: bool = False,
     ) -> None:
         super().__init__(batch_size, num_workers)
         self.iam_paragraphs = IAMParagraphs(
-            batch_size, num_workers, train_fraction, augment,
+            batch_size,
+            num_workers,
+            train_fraction,
+            augment,
+            word_pieces,
         )
         self.iam_synthetic_paragraphs = IAMSyntheticParagraphs(
-            batch_size, num_workers, train_fraction, augment,
+            batch_size,
+            num_workers,
+            train_fraction,
+            augment,
+            word_pieces,
         )
         self.dims = self.iam_paragraphs.dims
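Note: the new `word_pieces` flag above is only threaded through to the two child datamodules. A minimal, hypothetical usage sketch follows; the `prepare_data`/`setup` hooks are assumed to come from the Lightning-style `BaseDataModule`, which is not part of this diff:

```python
from text_recognizer.data import IAMExtendedParagraphs

# word_pieces=True switches targets from EMNIST character indices to
# word-piece indices via the WordPiece target transform added below.
datamodule = IAMExtendedParagraphs(batch_size=16, word_pieces=True)
datamodule.prepare_data()  # assumed hook inherited from BaseDataModule
datamodule.setup()
```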
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index f588587..62c44f9 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -5,8 +5,7 @@ from typing import Dict, List, Optional, Sequence, Tuple

 from loguru import logger
 import numpy as np
-from PIL import Image, ImageFile, ImageOps
-import torch
+from PIL import Image, ImageOps
 import torchvision.transforms as transforms
 from torchvision.transforms.functional import InterpolationMode
 from tqdm import tqdm
@@ -19,6 +18,7 @@ from text_recognizer.data.base_dataset import (
 )
 from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
 from text_recognizer.data.emnist import emnist_mapping
 from text_recognizer.data.iam import IAM
+from text_recognizer.data.transforms import WordPiece

 PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_paragraphs"
@@ -37,15 +37,15 @@ class IAMParagraphs(BaseDataModule):

     def __init__(
         self,
-        batch_size: int = 128,
+        batch_size: int = 16,
         num_workers: int = 0,
         train_fraction: float = 0.8,
         augment: bool = True,
+        word_pieces: bool = False,
     ) -> None:
         super().__init__(batch_size, num_workers)
-        # TODO: pass in transform and target transform
-        # TODO: pass in mapping
         self.augment = augment
+        self.word_pieces = word_pieces
         self.mapping, self.inverse_mapping, _ = emnist_mapping(
             extra_symbols=[NEW_LINE_TOKEN]
         )
@@ -101,6 +101,7 @@ class IAMParagraphs(BaseDataModule):
                 data,
                 targets,
                 transform=get_transform(image_shape=self.dims[1:], augment=augment),
+                target_transform=get_target_transform(self.word_pieces)
             )

         logger.info(f"Loading IAM paragraph regions and lines for {stage}...")
@@ -161,7 +162,10 @@ def get_dataset_properties() -> Dict:
             "min": min(_get_property_values("num_lines")),
             "max": max(_get_property_values("num_lines")),
         },
-        "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0),},
+        "crop_shape": {
+            "min": crop_shapes.min(axis=0),
+            "max": crop_shapes.max(axis=0),
+        },
         "aspect_ratio": {
             "min": aspect_ratio.min(axis=0),
             "max": aspect_ratio.max(axis=0),
@@ -282,7 +286,9 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Compose:
             ),
             transforms.ColorJitter(brightness=(0.8, 1.6)),
             transforms.RandomAffine(
-                degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR,
+                degrees=1,
+                shear=(-10, 10),
+                interpolation=InterpolationMode.BILINEAR,
             ),
         ]
     else:
@@ -290,6 +296,9 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Compose:
     transforms_list.append(transforms.ToTensor())
     return transforms.Compose(transforms_list)

+def get_target_transform(word_pieces: bool) -> Optional[transforms.Compose]:
+    """Transform emnist characters to word pieces."""
+    return transforms.Compose([WordPiece()]) if word_pieces else None

 def _labels_filename(split: str) -> Path:
     """Return filename of processed labels."""
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py
index 60f8a9f..b5f72da 100644
--- a/text_recognizer/data/iam_preprocessor.py
+++ b/text_recognizer/data/iam_preprocessor.py
@@ -89,6 +89,7 @@ class Preprocessor:
         self.lexicon = None

         if self.special_tokens is not None:
+            self.special_tokens += ("#", "*")
             self.tokens += self.special_tokens
             self.graphemes += self.special_tokens

diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index 9f1bd12..4ccc5c2 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -18,6 +18,7 @@ from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
 from text_recognizer.data.iam_paragraphs import (
     get_dataset_properties,
     get_transform,
+    get_target_transform,
     NEW_LINE_TOKEN,
     IAMParagraphs,
     IMAGE_SCALE_FACTOR,
@@ -41,12 +42,13 @@ class IAMSyntheticParagraphs(IAMParagraphs):

     def __init__(
         self,
-        batch_size: int = 128,
+        batch_size: int = 16,
         num_workers: int = 0,
         train_fraction: float = 0.8,
         augment: bool = True,
+        word_pieces: bool = False,
     ) -> None:
-        super().__init__(batch_size, num_workers, train_fraction, augment)
+        super().__init__(batch_size, num_workers, train_fraction, augment, word_pieces)

     def prepare_data(self) -> None:
         """Prepare IAM lines to be used to generate paragraphs."""
@@ -95,6 +97,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
             transform=get_transform(
                 image_shape=self.dims[1:], augment=self.augment
             ),
+            target_transform=get_target_transform(self.word_pieces)
         )

     def __repr__(self) -> str:
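The dataset-level effect of `get_target_transform` is not visible in this diff; assuming the usual torchvision convention where `target_transform` is applied to each target in `__getitem__`, it behaves roughly as in this sketch (constructing `WordPiece` requires the word-piece token and lexicon files on disk):

```python
import torch

from text_recognizer.data.iam_paragraphs import get_target_transform

# word_pieces=False: no transform, targets remain EMNIST character indices.
assert get_target_transform(word_pieces=False) is None

# word_pieces=True: a Compose wrapping WordPiece maps each target tensor of
# character indices to word-piece indices.
target_transform = get_target_transform(word_pieces=True)
character_target = torch.randint(0, 58, (128,))  # hypothetical target tensor
word_piece_target = target_transform(character_target)
```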
diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py
index cfa0ec7..f4016ba 100644
--- a/text_recognizer/data/mappings.py
+++ b/text_recognizer/data/mappings.py
@@ -8,7 +8,7 @@ import torch
 from torch import Tensor

 from text_recognizer.data.emnist import emnist_mapping
-from text_recognizer.datasets.iam_preprocessor import Preprocessor
+from text_recognizer.data.iam_preprocessor import Preprocessor


 class AbstractMapping(ABC):
@@ -57,14 +57,14 @@ class EmnistMapping(AbstractMapping):
 class WordPieceMapping(EmnistMapping):
     def __init__(
         self,
-        num_features: int,
-        tokens: str,
-        lexicon: str,
+        num_features: int = 1000,
+        tokens: str = "iamdb_1kwp_tokens_1000.txt",
+        lexicon: str = "iamdb_1kwp_lex_1000.txt",
         data_dir: Optional[Union[str, Path]] = None,
         use_words: bool = False,
         prepend_wordsep: bool = False,
         special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
-        extra_symbols: Optional[Sequence[str]] = None,
+        extra_symbols: Optional[Sequence[str]] = ("\n",),
     ) -> None:
         super().__init__(extra_symbols)
         self.wordpiece_processor = self._configure_wordpiece_processor(
@@ -78,8 +78,8 @@ class WordPieceMapping(EmnistMapping):
             extra_symbols,
         )

+    @staticmethod
     def _configure_wordpiece_processor(
-        self,
         num_features: int,
         tokens: str,
         lexicon: str,
@@ -90,7 +90,7 @@ class WordPieceMapping(EmnistMapping):
         extra_symbols: Optional[Sequence[str]],
     ) -> Preprocessor:
         data_dir = (
-            (Path(__file__).resolve().parents[2] / "data" / "raw" / "iam" / "iamdb")
+            (Path(__file__).resolve().parents[2] / "data" / "downloaded" / "iam" / "iamdb")
             if data_dir is None
             else Path(data_dir)
         )
@@ -138,6 +138,6 @@ class WordPieceMapping(EmnistMapping):
         return self.wordpiece_processor.to_index(text)

     def emnist_to_wordpiece_indices(self, x: Tensor) -> Tensor:
-        text = self.mapping.get_text(x)
+        text = "".join([self.mapping[i] for i in x])
         text = text.lower().replace(" ", "▁")
         return torch.LongTensor(self.wordpiece_processor.to_index(text))
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
index f53df64..8d1bedd 100644
--- a/text_recognizer/data/transforms.py
+++ b/text_recognizer/data/transforms.py
@@ -4,7 +4,7 @@ from typing import Optional, Union, Sequence

 from torch import Tensor

-from text_recognizer.datasets.mappings import WordPieceMapping
+from text_recognizer.data.mappings import WordPieceMapping


 class WordPiece:
@@ -12,14 +12,15 @@ class WordPiece:

     def __init__(
         self,
-        num_features: int,
-        tokens: str,
-        lexicon: str,
+        num_features: int = 1000,
+        tokens: str = "iamdb_1kwp_tokens_1000.txt",
+        lexicon: str = "iamdb_1kwp_lex_1000.txt",
         data_dir: Optional[Union[str, Path]] = None,
         use_words: bool = False,
         prepend_wordsep: bool = False,
         special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
-        extra_symbols: Optional[Sequence[str]] = None,
+        extra_symbols: Optional[Sequence[str]] = ("\n",),
+        max_len: int = 192,
     ) -> None:
         self.mapping = WordPieceMapping(
             num_features,
@@ -31,6 +32,7 @@ class WordPiece:
             special_tokens,
             extra_symbols,
         )
+        self.max_len = max_len

     def __call__(self, x: Tensor) -> Tensor:
-        return self.mapping.emnist_to_wordpiece_indices(x)
+        return self.mapping.emnist_to_wordpiece_indices(x)[:self.max_len]
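With the new defaults, `WordPiece` can be constructed without arguments, and `max_len` caps the emitted sequence length. A small, hypothetical sketch of the transform in isolation (assumes the `iamdb_1kwp_*` token and lexicon files exist under the downloaded data directory, per the path change above):

```python
import torch

from text_recognizer.data.transforms import WordPiece

# Defaults now point at the 1k word-piece token/lexicon files.
to_word_pieces = WordPiece(max_len=192)

# Hypothetical tensor of EMNIST character indices for one paragraph.
emnist_target = torch.randint(0, 58, (600,))
word_piece_indices = to_word_pieces(emnist_target)
assert len(word_piece_indices) <= 192  # __call__ truncates to max_len
```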