diff options
-rw-r--r-- | notebooks/03-look-at-iam-paragraphs.ipynb | 90 | ||||
-rw-r--r-- | text_recognizer/data/mapping.py | 8 | ||||
-rw-r--r-- | text_recognizer/models/base.py | 1 | ||||
-rw-r--r-- | text_recognizer/models/transformer.py | 5 | ||||
-rw-r--r-- | text_recognizer/networks/image_transformer.py | 23 | ||||
-rw-r--r-- | training/configs/image_transformer.yaml | 28 |
6 files changed, 123 insertions, 32 deletions
diff --git a/notebooks/03-look-at-iam-paragraphs.ipynb b/notebooks/03-look-at-iam-paragraphs.ipynb index 4b82034..cfa0ba5 100644 --- a/notebooks/03-look-at-iam-paragraphs.ipynb +++ b/notebooks/03-look-at-iam-paragraphs.ipynb @@ -25,12 +25,13 @@ " sys.path.append('..')\n", "\n", "from text_recognizer.data.iam_paragraphs import IAMParagraphs\n", - "from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs" + "from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs\n", + "from text_recognizer.data.iam_extended_paragraphs import IAMExtendedParagraphs" ] }, { "cell_type": "code", - "execution_count": 162, + "execution_count": 2, "id": "726ac25b", "metadata": {}, "outputs": [], @@ -47,6 +48,65 @@ }, { "cell_type": "code", + "execution_count": 3, + "id": "c6188bce", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2021-04-11 21:49:35.313 | INFO | text_recognizer.data.iam_paragraphs:setup:106 - Loading IAM paragraph regions and lines for None...\n", + "2021-04-11 21:49:51.802 | INFO | text_recognizer.data.iam_synthetic_paragraphs:setup:77 - IAM Synthetic dataset steup for stage None\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "IAM Original and Synthetic Paragraphs Dataset\n", + "Num classes: 84\n", + "Dims: (1, 576, 640)\n", + "Output dims: (682, 1)\n", + "Train/val/test sizes: 19942, 262, 231\n", + "Train Batch x stats: (torch.Size([128, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0099), tensor(0.0553), tensor(1.))\n", + "Train Batch y stats: (torch.Size([128, 682]), torch.int64, tensor(1), tensor(83))\n", + "Test Batch x stats: (torch.Size([128, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0284), tensor(0.0846), tensor(0.9373))\n", + "Test Batch y stats: (torch.Size([128, 682]), torch.int64, tensor(1), tensor(83))\n", + "\n" + ] + } + ], + "source": [ + "dataset = IAMExtendedParagraphs()\n", + "dataset.prepare_data()\n", + "dataset.setup()\n", + "print(dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "1b3c7bdd", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1246.375" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "19942 / 16" + ] + }, + { + "cell_type": "code", "execution_count": 4, "id": "42501428", "metadata": {}, @@ -152,7 +212,7 @@ }, { "cell_type": "code", - "execution_count": 165, + "execution_count": 5, "id": "45649194", "metadata": {}, "outputs": [], @@ -163,7 +223,7 @@ }, { "cell_type": "code", - "execution_count": 166, + "execution_count": 6, "id": "0fc13f9f", "metadata": {}, "outputs": [], @@ -181,6 +241,27 @@ }, { "cell_type": "code", + "execution_count": 8, + "id": "fb0afccf", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1004" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(processor.tokens)" + ] + }, + { + "cell_type": "code", "execution_count": 167, "id": "d08a0259", "metadata": {}, @@ -435,7 +516,6 @@ } ], "source": [ - "\n", "# Testing\n", "\n", "for _ in range(5):\n", diff --git a/text_recognizer/data/mapping.py b/text_recognizer/data/mapping.py new file mode 100644 index 0000000..f0edf3f --- /dev/null +++ b/text_recognizer/data/mapping.py @@ -0,0 +1,8 @@ +"""Mapping to and from word pieces.""" +from pathlib import Path + + +class WordPieces: + + def __init__(self) -> None: + pass diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 0928e6c..c6d5d73 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -60,6 +60,7 @@ class LitBaseModel(pl.LightningModule): scheduler["scheduler"] = getattr( torch.optim.lr_scheduler, self._lr_scheduler.type )(optimizer, **args) + return scheduler def configure_optimizers(self) -> Tuple[List[type], List[Dict[str, Any]]]: diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index b23685b..7dc1352 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -1,5 +1,5 @@ """PyTorch Lightning model for base Transformers.""" -from typing import Dict, List, Optional, Union, Tuple +from typing import Dict, List, Optional, Union, Tuple, Type from omegaconf import DictConfig, OmegaConf import pytorch_lightning as pl @@ -19,7 +19,7 @@ class LitTransformerModel(LitBaseModel): def __init__( self, - network: Type[nn, Module], + network: Type[nn.Module], optimizer: Union[DictConfig, Dict], lr_scheduler: Union[DictConfig, Dict], criterion: Union[DictConfig, Dict], @@ -27,7 +27,6 @@ class LitTransformerModel(LitBaseModel): mapping: Optional[List[str]] = None, ) -> None: super().__init__(network, optimizer, lr_scheduler, criterion, monitor) - self.mapping, ignore_tokens = self.configure_mapping(mapping) self.val_cer = CharacterErrorRate(ignore_tokens) self.test_cer = CharacterErrorRate(ignore_tokens) diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py index 9ed67a4..daededa 100644 --- a/text_recognizer/networks/image_transformer.py +++ b/text_recognizer/networks/image_transformer.py @@ -10,16 +10,15 @@ TODO: Local attention for lower layer in attention. """ import importlib import math -from typing import Dict, List, Union, Sequence, Tuple, Type +from typing import Dict, Optional, Union, Sequence, Type from einops import rearrange from omegaconf import DictConfig, OmegaConf import torch from torch import nn from torch import Tensor -import torchvision -from text_recognizer.data.emnist import emnist_mapping +from text_recognizer.data.emnist import NUM_SPECIAL_TOKENS from text_recognizer.networks.transformer import ( Decoder, DecoderLayer, @@ -28,6 +27,8 @@ from text_recognizer.networks.transformer import ( target_padding_mask, ) +NUM_WORD_PIECES = 1000 + class ImageTransformer(nn.Module): def __init__( @@ -35,7 +36,7 @@ class ImageTransformer(nn.Module): input_shape: Sequence[int], output_shape: Sequence[int], encoder: Union[DictConfig, Dict], - mapping: str, + vocab_size: Optional[int] = None, num_decoder_layers: int = 4, hidden_dim: int = 256, num_heads: int = 4, @@ -43,14 +44,9 @@ class ImageTransformer(nn.Module): dropout_rate: float = 0.1, transformer_activation: str = "glu", ) -> None: - # Configure mapping - mapping, inverse_mapping = self._configure_mapping(mapping) - self.vocab_size = len(mapping) + self.vocab_size = NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size self.hidden_dim = hidden_dim self.max_output_length = output_shape[0] - self.start_index = inverse_mapping["<s>"] - self.end_index = inverse_mapping["<e>"] - self.pad_index = inverse_mapping["<p>"] # Image backbone self.encoder = self._configure_encoder(encoder) @@ -107,13 +103,6 @@ class ImageTransformer(nn.Module): encoder_class = getattr(network_module, encoder.type) return encoder_class(**encoder.args) - def _configure_mapping(self, mapping: str) -> Tuple[List[str], Dict[str, int]]: - """Configures mapping.""" - # TODO: Fix me!!! - if mapping == "emnist": - mapping, inverse_mapping, _ = emnist_mapping() - return mapping, inverse_mapping - def encode(self, image: Tensor) -> Tensor: """Extracts image features with backbone. diff --git a/training/configs/image_transformer.yaml b/training/configs/image_transformer.yaml index bedcbb5..88c05c2 100644 --- a/training/configs/image_transformer.yaml +++ b/training/configs/image_transformer.yaml @@ -1,7 +1,7 @@ seed: 4711 network: - desc: null + desc: Configuration of the PyTorch neural network. type: ImageTransformer args: encoder: @@ -15,20 +15,24 @@ network: transformer_activation: glu model: - desc: null + desc: Configuration of the PyTorch Lightning model. type: LitTransformerModel args: optimizer: type: MADGRAD args: - lr: 1.0e-2 + lr: 1.0e-3 momentum: 0.9 weight_decay: 0 eps: 1.0e-6 lr_scheduler: - type: CosineAnnealingLR + type: OneCycle args: - T_max: 512 + interval: &interval step + max_lr: 1.0e-3 + three_phase: true + epochs: 512 + steps_per_epoch: 1246 # num_samples / batch_size criterion: type: CrossEntropyLoss args: @@ -39,7 +43,7 @@ model: mapping: sentence_piece data: - desc: null + desc: Configuration of the training/test data. type: IAMExtendedParagraphs args: batch_size: 16 @@ -52,6 +56,16 @@ callbacks: args: monitor: val_loss mode: min + - type: StochasticWeightAveraging + args: + swa_epoch_start: 0.8 + swa_lrs: 0.05 + annealing_epochs: 10 + annealing_strategy: cos + device: null + - type: LearningRateMonitor + args: + logging_interval: *interval - type: EarlyStopping args: monitor: val_loss @@ -59,7 +73,7 @@ callbacks: patience: 10 trainer: - desc: null + desc: Configuration of the PyTorch Lightning Trainer. args: stochastic_weight_avg: true auto_scale_batch_size: binsearch |