diff options
-rw-r--r-- | README.md | 5 | ||||
-rw-r--r-- | notebooks/00-testing-stuff-out.ipynb | 109 | ||||
-rw-r--r-- | text_recognizer/data/iam_extended_paragraphs.py | 10 | ||||
-rw-r--r-- | text_recognizer/data/iam_paragraphs.py | 9 | ||||
-rw-r--r-- | text_recognizer/models/base.py | 6 | ||||
-rw-r--r-- | text_recognizer/models/transformer.py | 6 | ||||
-rw-r--r-- | text_recognizer/networks/image_transformer.py | 10 | ||||
-rw-r--r-- | text_recognizer/networks/residual_network.py | 6 | ||||
-rw-r--r-- | text_recognizer/networks/transducer/transducer.py | 7 | ||||
-rw-r--r-- | text_recognizer/networks/vqvae/decoder.py | 18 | ||||
-rw-r--r-- | text_recognizer/networks/vqvae/encoder.py | 12 | ||||
-rw-r--r-- | training/experiments/image_transformer.yaml | 12 | ||||
-rw-r--r-- | training/run_experiment.py | 4 |
13 files changed, 180 insertions, 34 deletions
@@ -11,7 +11,7 @@ TBC Extract text from the iam dataset: ``` -poetry run extract-iam-text --use_words --save_text train.txt --save_tokens letters.txt +poetry run extract-iam-text --use_words --save_text train.txt --save_tokens letters.txt ``` Create word pieces from the extracted training text: @@ -32,7 +32,7 @@ poetry run build-transitions --tokens iamdb_1kwp_tokens_1000.txt --lexicon iamdb - [x] transform that encodes iam targets to wordpieces - [x] transducer loss function - [ ] Train with word pieces -- [ ] Local attention in first layer of transformer +- [ ] Local attention in first layer of transformer - [ ] Halonet encoder - [ ] Implement CPC - [ ] https://arxiv.org/pdf/1905.09272.pdf @@ -59,4 +59,3 @@ export SWEEP_ID=... wandb agent $SWEEP_ID ``` - diff --git a/notebooks/00-testing-stuff-out.ipynb b/notebooks/00-testing-stuff-out.ipynb index 8f2e3f8..4c93501 100644 --- a/notebooks/00-testing-stuff-out.ipynb +++ b/notebooks/00-testing-stuff-out.ipynb @@ -26,6 +26,115 @@ { "cell_type": "code", "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from omegaconf import OmegaConf" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "path = \"../training/experiments/image_transformer.yaml\"" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "conf = OmegaConf.load(path)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "network:\n", + " type: ImageTransformer\n", + " args:\n", + " input_shape: None\n", + " output_shape: None\n", + " encoder:\n", + " type: None\n", + " args: None\n", + " mapping: sentence_piece\n", + " num_decoder_layers: 4\n", + " hidden_dim: 256\n", + " num_heads: 4\n", + " expansion_dim: 1024\n", + " dropout_rate: 0.1\n", + " transformer_activation: glu\n", + "model:\n", + " type: LitTransformerModel\n", + " args:\n", + " optimizer:\n", + " type: MADGRAD\n", + " args:\n", + " lr: 0.01\n", + " momentum: 0.9\n", + " weight_decay: 0\n", + " eps: 1.0e-06\n", + " lr_scheduler:\n", + " type: CosineAnnealingLR\n", + " args:\n", + " T_max: 512\n", + " criterion:\n", + " type: CrossEntropyLoss\n", + " args:\n", + " weight: None\n", + " ignore_index: -100\n", + " reduction: mean\n", + " monitor: val_loss\n", + " mapping: sentence_piece\n", + "data:\n", + " type: IAMExtendedParagraphs\n", + " args:\n", + " batch_size: 16\n", + " num_workers: 12\n", + " train_fraction: 0.8\n", + " augment: true\n", + "callbacks:\n", + "- type: ModelCheckpoint\n", + " args:\n", + " monitor: val_loss\n", + " mode: min\n", + "- type: EarlyStopping\n", + " args:\n", + " monitor: val_loss\n", + " mode: min\n", + " patience: 10\n", + "trainer:\n", + " args:\n", + " stochastic_weight_avg: true\n", + " auto_scale_batch_size: power\n", + " gradient_clip_val: 0\n", + " fast_dev_run: false\n", + " gpus: 1\n", + " precision: 16\n", + " max_epocs: 512\n", + " terminate_on_nan: true\n", + " weights_summary: true\n", + "\n" + ] + } + ], + "source": [ + "print(OmegaConf.to_yaml(conf))" + ] + }, + { + "cell_type": "code", + "execution_count": 1, "metadata": { "scrolled": true }, diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py index d2529b4..c144341 100644 --- a/text_recognizer/data/iam_extended_paragraphs.py +++ b/text_recognizer/data/iam_extended_paragraphs.py @@ -18,10 +18,16 @@ class IAMExtendedParagraphs(BaseDataModule): super().__init__(batch_size, num_workers) self.iam_paragraphs = IAMParagraphs( - batch_size, num_workers, train_fraction, augment, + batch_size, + num_workers, + train_fraction, + augment, ) self.iam_synthetic_paragraphs = IAMSyntheticParagraphs( - batch_size, num_workers, train_fraction, augment, + batch_size, + num_workers, + train_fraction, + augment, ) self.dims = self.iam_paragraphs.dims diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index f588587..314d458 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -161,7 +161,10 @@ def get_dataset_properties() -> Dict: "min": min(_get_property_values("num_lines")), "max": max(_get_property_values("num_lines")), }, - "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0),}, + "crop_shape": { + "min": crop_shapes.min(axis=0), + "max": crop_shapes.max(axis=0), + }, "aspect_ratio": { "min": aspect_ratio.min(axis=0), "max": aspect_ratio.max(axis=0), @@ -282,7 +285,9 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Com ), transforms.ColorJitter(brightness=(0.8, 1.6)), transforms.RandomAffine( - degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR, + degrees=1, + shear=(-10, 10), + interpolation=InterpolationMode.BILINEAR, ), ] else: diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 1004f48..11d1eb1 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -15,7 +15,7 @@ class LitBaseModel(pl.LightningModule): def __init__( self, - network: Type[nn,Module], + network: Type[nn.Module], optimizer: Union[OmegaConf, Dict], lr_scheduler: Union[OmegaConf, Dict], criterion: Union[OmegaConf, Dict], @@ -40,14 +40,14 @@ class LitBaseModel(pl.LightningModule): args = {} or criterion.args return getattr(nn, criterion.type)(**args) - def _configure_optimizer(self) -> type: + def _configure_optimizer(self) -> torch.optim.Optimizer: """Configures the optimizer.""" args = {} or self._optimizer.args if self._optimizer.type == "MADGRAD": optimizer_class = madgrad.MADGRAD else: optimizer_class = getattr(torch.optim, self._optimizer.type) - return optimizer_class(parameters=self.parameters(), **args) + return optimizer_class(params=self.parameters(), **args) def _configure_lr_scheduler(self) -> Dict[str, Any]: """Configures the lr scheduler.""" diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 3625ab2..983e274 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -19,16 +19,14 @@ class LitTransformerModel(LitBaseModel): def __init__( self, - network: Type[nn,Module], + network: Type[nn, Module], optimizer: Union[OmegaConf, Dict], lr_scheduler: Union[OmegaConf, Dict], criterion: Union[OmegaConf, Dict], monitor: str = "val_loss", mapping: Optional[List[str]] = None, ) -> None: - super().__init__( - network, optimizer, lr_scheduler, criterion, monitor - ) + super().__init__(network, optimizer, lr_scheduler, criterion, monitor) self.mapping, ignore_tokens = self.configure_mapping(mapping) self.val_cer = CharacterErrorRate(ignore_tokens) diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py index aa024e0..85a84d2 100644 --- a/text_recognizer/networks/image_transformer.py +++ b/text_recognizer/networks/image_transformer.py @@ -1,9 +1,9 @@ """A Transformer with a cnn backbone. The network encodes a image with a convolutional backbone to a latent representation, -i.e. feature maps. A 2d positional encoding is applied to the feature maps for +i.e. feature maps. A 2d positional encoding is applied to the feature maps for spatial information. The resulting feature are then set to a transformer decoder -together with the target tokens. +together with the target tokens. TODO: Local attention for transformer.j @@ -107,9 +107,7 @@ class ImageTransformer(nn.Module): encoder_class = getattr(network_module, encoder.type) return encoder_class(**encoder.args) - def _configure_mapping( - self, mapping: str - ) -> Tuple[List[str], Dict[str, int]]: + def _configure_mapping(self, mapping: str) -> Tuple[List[str], Dict[str, int]]: """Configures mapping.""" if mapping == "emnist": mapping, inverse_mapping, _ = emnist_mapping() @@ -125,7 +123,7 @@ class ImageTransformer(nn.Module): Tensor: Image features. Shapes: - - image: :math: `(B, C, H, W)` + - image: :math: `(B, C, H, W)` - latent: :math: `(B, T, C)` """ diff --git a/text_recognizer/networks/residual_network.py b/text_recognizer/networks/residual_network.py index c33f419..da7553d 100644 --- a/text_recognizer/networks/residual_network.py +++ b/text_recognizer/networks/residual_network.py @@ -20,7 +20,11 @@ class Conv2dAuto(nn.Conv2d): def conv_bn(in_channels: int, out_channels: int, *args, **kwargs) -> nn.Sequential: """3x3 convolution with batch norm.""" - conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,) + conv3x3 = partial( + Conv2dAuto, + kernel_size=3, + bias=False, + ) return nn.Sequential( conv3x3(in_channels, out_channels, *args, **kwargs), nn.BatchNorm2d(out_channels), diff --git a/text_recognizer/networks/transducer/transducer.py b/text_recognizer/networks/transducer/transducer.py index d7e3d08..b10f93a 100644 --- a/text_recognizer/networks/transducer/transducer.py +++ b/text_recognizer/networks/transducer/transducer.py @@ -392,7 +392,12 @@ def load_transducer_loss( transitions = gtn.load(str(processed_path / transitions)) preprocessor = Preprocessor( - data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep, + data_dir, + num_features, + tokens_path, + lexicon_path, + use_words, + prepend_wordsep, ) num_tokens = preprocessor.num_tokens diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py index 8847aba..67ed0d9 100644 --- a/text_recognizer/networks/vqvae/decoder.py +++ b/text_recognizer/networks/vqvae/decoder.py @@ -44,7 +44,12 @@ class Decoder(nn.Module): # Configure encoder. self.decoder = self._build_decoder( - channels, kernel_sizes, strides, num_residual_layers, activation, dropout, + channels, + kernel_sizes, + strides, + num_residual_layers, + activation, + dropout, ) def _build_decompression_block( @@ -73,7 +78,9 @@ class Decoder(nn.Module): ) if i < len(self.upsampling): - modules.append(nn.Upsample(size=self.upsampling[i]),) + modules.append( + nn.Upsample(size=self.upsampling[i]), + ) if dropout is not None: modules.append(dropout) @@ -102,7 +109,12 @@ class Decoder(nn.Module): ) -> nn.Sequential: self.res_block.append( - nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,) + nn.Conv2d( + self.embedding_dim, + channels[0], + kernel_size=1, + stride=1, + ) ) # Bottleneck module. diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py index d3adac5..ede5c31 100644 --- a/text_recognizer/networks/vqvae/encoder.py +++ b/text_recognizer/networks/vqvae/encoder.py @@ -11,7 +11,10 @@ from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer class _ResidualBlock(nn.Module): def __init__( - self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]], + self, + in_channels: int, + out_channels: int, + dropout: Optional[Type[nn.Module]], ) -> None: super().__init__() self.block = [ @@ -135,7 +138,12 @@ class Encoder(nn.Module): ) encoder.append( - nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,) + nn.Conv2d( + channels[-1], + self.embedding_dim, + kernel_size=1, + stride=1, + ) ) return nn.Sequential(*encoder) diff --git a/training/experiments/image_transformer.yaml b/training/experiments/image_transformer.yaml index 7f0bbb7..012a19b 100644 --- a/training/experiments/image_transformer.yaml +++ b/training/experiments/image_transformer.yaml @@ -1,6 +1,6 @@ network: type: ImageTransformer - args: + args: input_shape: None output_shape: None encoder: @@ -17,20 +17,20 @@ network: model: type: LitTransformerModel args: - optimizer: + optimizer: type: MADGRAD args: lr: 1.0e-2 momentum: 0.9 weight_decay: 0 eps: 1.0e-6 - lr_scheduler: + lr_scheduler: type: CosineAnnealingLR - args: + args: T_max: 512 criterion: type: CrossEntropyLoss - args: + args: weight: None ignore_index: -100 reduction: mean @@ -40,7 +40,7 @@ model: data: type: IAMExtendedParagraphs - args: + args: batch_size: 16 num_workers: 12 train_fraction: 0.8 diff --git a/training/run_experiment.py b/training/run_experiment.py index 8a29555..0a67bfa 100644 --- a/training/run_experiment.py +++ b/training/run_experiment.py @@ -50,7 +50,9 @@ def _import_class(module_and_class_name: str) -> type: return getattr(module, class_name) -def _configure_pl_callbacks(args: List[Union[OmegaConf, NamedTuple]]) -> List[Type[pl.callbacks.Callback]]: +def _configure_pl_callbacks( + args: List[Union[OmegaConf, NamedTuple]] +) -> List[Type[pl.callbacks.Callback]]: """Configures PyTorch Lightning callbacks.""" pl_callbacks = [ getattr(pl.callbacks, callback.type)(**callback.args) for callback in args |