diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-22 08:15:58 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-22 08:15:58 +0200 |
commit | 1ca8b0b9e0613c1e02f6a5d8b49e20c4d6916412 (patch) | |
tree | 5e610ac459c9b254f8826e92372346f01f8e2412 /text_recognizer/networks | |
parent | ffa4be4bf4e3758e01d52a9c1f354a05a90b93de (diff) |
Fixed training script, able to train vqvae
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/__init__.py | 2 | ||||
-rw-r--r-- | text_recognizer/networks/cnn_transformer.py | 257 | ||||
-rw-r--r-- | text_recognizer/networks/image_transformer.py | 165 | ||||
-rw-r--r-- | text_recognizer/networks/residual_network.py | 6 | ||||
-rw-r--r-- | text_recognizer/networks/transducer/transducer.py | 7 | ||||
-rw-r--r-- | text_recognizer/networks/vqvae/decoder.py | 20 | ||||
-rw-r--r-- | text_recognizer/networks/vqvae/encoder.py | 30 | ||||
-rw-r--r-- | text_recognizer/networks/vqvae/vqvae.py | 5 |
8 files changed, 182 insertions, 310 deletions
diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py index 979149f..41fd43f 100644 --- a/text_recognizer/networks/__init__.py +++ b/text_recognizer/networks/__init__.py @@ -1,2 +1,2 @@ """Network modules""" -from .image_transformer import ImageTransformer +from .vqvae import VQVAE diff --git a/text_recognizer/networks/cnn_transformer.py b/text_recognizer/networks/cnn_transformer.py index 9150b55..e23a15d 100644 --- a/text_recognizer/networks/cnn_transformer.py +++ b/text_recognizer/networks/cnn_transformer.py @@ -1,158 +1,165 @@ -"""A CNN-Transformer for image to text recognition.""" -from typing import Dict, Optional, Tuple +"""A Transformer with a cnn backbone. + +The network encodes a image with a convolutional backbone to a latent representation, +i.e. feature maps. A 2d positional encoding is applied to the feature maps for +spatial information. The resulting feature are then set to a transformer decoder +together with the target tokens. + +TODO: Local attention for lower layer in attention. + +""" +import importlib +import math +from typing import Dict, Optional, Union, Sequence, Type from einops import rearrange +from omegaconf import DictConfig, OmegaConf import torch from torch import nn from torch import Tensor -from text_recognizer.networks.transformer import PositionalEncoding, Transformer -from text_recognizer.networks.util import activation_function -from text_recognizer.networks.util import configure_backbone +from text_recognizer.data.emnist import NUM_SPECIAL_TOKENS +from text_recognizer.networks.transformer import ( + Decoder, + DecoderLayer, + PositionalEncoding, + PositionalEncoding2D, + target_padding_mask, +) +NUM_WORD_PIECES = 1000 -class CNNTransformer(nn.Module): - """CNN+Transfomer for image to sequence prediction.""" +class CNNTransformer(nn.Module): def __init__( self, - num_encoder_layers: int, - num_decoder_layers: int, - hidden_dim: int, - vocab_size: int, - num_heads: int, - adaptive_pool_dim: Tuple, - expansion_dim: int, - dropout_rate: float, - trg_pad_index: int, - max_len: int, - backbone: str, - backbone_args: Optional[Dict] = None, - activation: str = "gelu", - pool_kernel: Optional[Tuple[int, int]] = None, + input_shape: Sequence[int], + output_shape: Sequence[int], + encoder: Union[DictConfig, Dict], + vocab_size: Optional[int] = None, + num_decoder_layers: int = 4, + hidden_dim: int = 256, + num_heads: int = 4, + expansion_dim: int = 1024, + dropout_rate: float = 0.1, + transformer_activation: str = "glu", ) -> None: - super().__init__() - self.trg_pad_index = trg_pad_index - self.vocab_size = vocab_size - self.backbone = configure_backbone(backbone, backbone_args) - - if pool_kernel is not None: - self.max_pool = nn.MaxPool2d(pool_kernel, stride=2) - else: - self.max_pool = None - - self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim) - - self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim)) - self.pos_dropout = nn.Dropout(p=dropout_rate) - self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate) - - nn.init.normal_(self.character_embedding.weight, std=0.02) - - self.adaptive_pool = ( - nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None + self.vocab_size = ( + NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size ) + self.hidden_dim = hidden_dim + self.max_output_length = output_shape[0] - self.transformer = Transformer( - num_encoder_layers, - num_decoder_layers, - hidden_dim, - num_heads, - expansion_dim, - dropout_rate, - activation, + # Image backbone + self.encoder = self._configure_encoder(encoder) + self.feature_map_encoding = PositionalEncoding2D( + hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2] ) - self.head = nn.Sequential( - # nn.Linear(hidden_dim, hidden_dim * 2), - # activation_function(activation), - nn.Linear(hidden_dim, vocab_size), - ) + # Target token embedding + self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim) + self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate) - def _create_trg_mask(self, trg: Tensor) -> Tensor: - # Move this outside the transformer. - trg_pad_mask = (trg != self.trg_pad_index)[:, None, None] - trg_len = trg.shape[1] - trg_sub_mask = torch.tril( - torch.ones((trg_len, trg_len), device=trg.device) - ).bool() - trg_mask = trg_pad_mask & trg_sub_mask - return trg_mask - - def encoder(self, src: Tensor) -> Tensor: - """Forward pass with the encoder of the transformer.""" - return self.transformer.encoder(src) - - def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor: - """Forward pass with the decoder of the transformer + classification head.""" - return self.head( - self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask) + # Transformer decoder + self.decoder = Decoder( + decoder_layer=DecoderLayer( + hidden_dim=hidden_dim, + num_heads=num_heads, + expansion_dim=expansion_dim, + dropout_rate=dropout_rate, + activation=transformer_activation, + ), + num_layers=num_decoder_layers, + norm=nn.LayerNorm(hidden_dim), ) - def extract_image_features(self, src: Tensor) -> Tensor: - """Extracts image features with a backbone neural network. - - It seem like the winning idea was to swap channels and width dimension and collapse - the height dimension. The transformer is learning like a baby with this implementation!!! :D - Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D + # Classification head + self.head = nn.Linear(hidden_dim, self.vocab_size) - Args: - src (Tensor): Input tensor. + # Initialize weights + self._init_weights() - Returns: - Tensor: A input src to the transformer. + def _init_weights(self) -> None: + """Initialize network weights.""" + self.trg_embedding.weight.data.uniform_(-0.1, 0.1) + self.head.bias.data.zero_() + self.head.weight.data.uniform_(-0.1, 0.1) - """ - # If batch dimension is missing, it needs to be added. - if len(src.shape) < 4: - src = src[(None,) * (4 - len(src.shape))] - - src = self.backbone(src) - - if self.max_pool is not None: - src = self.max_pool(src) - - if self.adaptive_pool is not None and len(src.shape) == 4: - src = rearrange(src, "b c h w -> b w c h") - src = self.adaptive_pool(src) - src = src.squeeze(3) - elif len(src.shape) == 4: - src = rearrange(src, "b c h w -> b (h w) c") + nn.init.kaiming_normal_( + self.feature_map_encoding.weight.data, + a=0, + mode="fan_out", + nonlinearity="relu", + ) + if self.feature_map_encoding.bias is not None: + _, fan_out = nn.init._calculate_fan_in_and_fan_out( + self.feature_map_encoding.weight.data + ) + bound = 1 / math.sqrt(fan_out) + nn.init.normal_(self.feature_map_encoding.bias, -bound, bound) + + @staticmethod + def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]: + encoder = OmegaConf.create(encoder) + network_module = importlib.import_module("text_recognizer.networks") + encoder_class = getattr(network_module, encoder.type) + return encoder_class(**encoder.args) + + def encode(self, image: Tensor) -> Tensor: + """Extracts image features with backbone. - b, t, _ = src.shape + Args: + image (Tensor): Image(s) of handwritten text. - src += self.src_position_embedding[:, :t] - src = self.pos_dropout(src) + Retuns: + Tensor: Image features. - return src + Shapes: + - image: :math: `(B, C, H, W)` + - latent: :math: `(B, T, C)` - def target_embedding(self, trg: Tensor) -> Tuple[Tensor, Tensor]: - """Encodes target tensor with embedding and postion. + """ + # Extract image features. + image_features = self.encoder(image) - Args: - trg (Tensor): Target tensor. + # Add 2d encoding to the feature maps. + image_features = self.feature_map_encoding(image_features) - Returns: - Tuple[Tensor, Tensor]: Encoded target tensor and target mask. + # Collapse features maps height and width. + image_features = rearrange(image_features, "b c h w -> b (h w) c") + return image_features - """ - trg = self.character_embedding(trg.long()) + def decode(self, memory: Tensor, trg: Tensor) -> Tensor: + """Decodes image features with transformer decoder.""" + trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index) + trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim) trg = self.trg_position_encoding(trg) - return trg - - def decode_image_features( - self, image_features: Tensor, trg: Optional[Tensor] = None - ) -> Tensor: - """Takes images features from the backbone and decodes them with the transformer.""" - trg_mask = self._create_trg_mask(trg) - trg = self.target_embedding(trg) - out = self.transformer(image_features, trg, trg_mask=trg_mask) - + out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None) logits = self.head(out) return logits - def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor: - """Forward pass with CNN transfomer.""" - image_features = self.extract_image_features(x) - logits = self.decode_image_features(image_features, trg) - return logits + def predict(self, image: Tensor) -> Tensor: + """Transcribes text in image(s).""" + bsz = image.shape[0] + image_features = self.encode(image) + + output_tokens = ( + (torch.ones((bsz, self.max_output_length)) * self.pad_index) + .type_as(image) + .long() + ) + output_tokens[:, 0] = self.start_index + for i in range(1, self.max_output_length): + trg = output_tokens[:, :i] + output = self.decode(image_features, trg) + output = torch.argmax(output, dim=-1) + output_tokens[:, i] = output[-1:] + + # Set all tokens after end token to be padding. + for i in range(1, self.max_output_length): + indices = output_tokens[:, i - 1] == self.end_index | ( + output_tokens[:, i - 1] == self.pad_index + ) + output_tokens[indices, i] = self.pad_index + + return output_tokens diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py deleted file mode 100644 index a6aaca4..0000000 --- a/text_recognizer/networks/image_transformer.py +++ /dev/null @@ -1,165 +0,0 @@ -"""A Transformer with a cnn backbone. - -The network encodes a image with a convolutional backbone to a latent representation, -i.e. feature maps. A 2d positional encoding is applied to the feature maps for -spatial information. The resulting feature are then set to a transformer decoder -together with the target tokens. - -TODO: Local attention for lower layer in attention. - -""" -import importlib -import math -from typing import Dict, Optional, Union, Sequence, Type - -from einops import rearrange -from omegaconf import DictConfig, OmegaConf -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.data.emnist import NUM_SPECIAL_TOKENS -from text_recognizer.networks.transformer import ( - Decoder, - DecoderLayer, - PositionalEncoding, - PositionalEncoding2D, - target_padding_mask, -) - -NUM_WORD_PIECES = 1000 - - -class ImageTransformer(nn.Module): - def __init__( - self, - input_shape: Sequence[int], - output_shape: Sequence[int], - encoder: Union[DictConfig, Dict], - vocab_size: Optional[int] = None, - num_decoder_layers: int = 4, - hidden_dim: int = 256, - num_heads: int = 4, - expansion_dim: int = 1024, - dropout_rate: float = 0.1, - transformer_activation: str = "glu", - ) -> None: - self.vocab_size = ( - NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size - ) - self.hidden_dim = hidden_dim - self.max_output_length = output_shape[0] - - # Image backbone - self.encoder = self._configure_encoder(encoder) - self.feature_map_encoding = PositionalEncoding2D( - hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2] - ) - - # Target token embedding - self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim) - self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate) - - # Transformer decoder - self.decoder = Decoder( - decoder_layer=DecoderLayer( - hidden_dim=hidden_dim, - num_heads=num_heads, - expansion_dim=expansion_dim, - dropout_rate=dropout_rate, - activation=transformer_activation, - ), - num_layers=num_decoder_layers, - norm=nn.LayerNorm(hidden_dim), - ) - - # Classification head - self.head = nn.Linear(hidden_dim, self.vocab_size) - - # Initialize weights - self._init_weights() - - def _init_weights(self) -> None: - """Initialize network weights.""" - self.trg_embedding.weight.data.uniform_(-0.1, 0.1) - self.head.bias.data.zero_() - self.head.weight.data.uniform_(-0.1, 0.1) - - nn.init.kaiming_normal_( - self.feature_map_encoding.weight.data, - a=0, - mode="fan_out", - nonlinearity="relu", - ) - if self.feature_map_encoding.bias is not None: - _, fan_out = nn.init._calculate_fan_in_and_fan_out( - self.feature_map_encoding.weight.data - ) - bound = 1 / math.sqrt(fan_out) - nn.init.normal_(self.feature_map_encoding.bias, -bound, bound) - - @staticmethod - def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]: - encoder = OmegaConf.create(encoder) - network_module = importlib.import_module("text_recognizer.networks") - encoder_class = getattr(network_module, encoder.type) - return encoder_class(**encoder.args) - - def encode(self, image: Tensor) -> Tensor: - """Extracts image features with backbone. - - Args: - image (Tensor): Image(s) of handwritten text. - - Retuns: - Tensor: Image features. - - Shapes: - - image: :math: `(B, C, H, W)` - - latent: :math: `(B, T, C)` - - """ - # Extract image features. - image_features = self.encoder(image) - - # Add 2d encoding to the feature maps. - image_features = self.feature_map_encoding(image_features) - - # Collapse features maps height and width. - image_features = rearrange(image_features, "b c h w -> b (h w) c") - return image_features - - def decode(self, memory: Tensor, trg: Tensor) -> Tensor: - """Decodes image features with transformer decoder.""" - trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index) - trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim) - trg = self.trg_position_encoding(trg) - out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None) - logits = self.head(out) - return logits - - def predict(self, image: Tensor) -> Tensor: - """Transcribes text in image(s).""" - bsz = image.shape[0] - image_features = self.encode(image) - - output_tokens = ( - (torch.ones((bsz, self.max_output_length)) * self.pad_index) - .type_as(image) - .long() - ) - output_tokens[:, 0] = self.start_index - for i in range(1, self.max_output_length): - trg = output_tokens[:, :i] - output = self.decode(image_features, trg) - output = torch.argmax(output, dim=-1) - output_tokens[:, i] = output[-1:] - - # Set all tokens after end token to be padding. - for i in range(1, self.max_output_length): - indices = output_tokens[:, i - 1] == self.end_index | ( - output_tokens[:, i - 1] == self.pad_index - ) - output_tokens[indices, i] = self.pad_index - - return output_tokens diff --git a/text_recognizer/networks/residual_network.py b/text_recognizer/networks/residual_network.py index c33f419..da7553d 100644 --- a/text_recognizer/networks/residual_network.py +++ b/text_recognizer/networks/residual_network.py @@ -20,7 +20,11 @@ class Conv2dAuto(nn.Conv2d): def conv_bn(in_channels: int, out_channels: int, *args, **kwargs) -> nn.Sequential: """3x3 convolution with batch norm.""" - conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,) + conv3x3 = partial( + Conv2dAuto, + kernel_size=3, + bias=False, + ) return nn.Sequential( conv3x3(in_channels, out_channels, *args, **kwargs), nn.BatchNorm2d(out_channels), diff --git a/text_recognizer/networks/transducer/transducer.py b/text_recognizer/networks/transducer/transducer.py index d7e3d08..b10f93a 100644 --- a/text_recognizer/networks/transducer/transducer.py +++ b/text_recognizer/networks/transducer/transducer.py @@ -392,7 +392,12 @@ def load_transducer_loss( transitions = gtn.load(str(processed_path / transitions)) preprocessor = Preprocessor( - data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep, + data_dir, + num_features, + tokens_path, + lexicon_path, + use_words, + prepend_wordsep, ) num_tokens = preprocessor.num_tokens diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py index 8847aba..93a1e43 100644 --- a/text_recognizer/networks/vqvae/decoder.py +++ b/text_recognizer/networks/vqvae/decoder.py @@ -44,7 +44,12 @@ class Decoder(nn.Module): # Configure encoder. self.decoder = self._build_decoder( - channels, kernel_sizes, strides, num_residual_layers, activation, dropout, + channels, + kernel_sizes, + strides, + num_residual_layers, + activation, + dropout, ) def _build_decompression_block( @@ -72,8 +77,10 @@ class Decoder(nn.Module): ) ) - if i < len(self.upsampling): - modules.append(nn.Upsample(size=self.upsampling[i]),) + if self.upsampling and i < len(self.upsampling): + modules.append( + nn.Upsample(size=self.upsampling[i]), + ) if dropout is not None: modules.append(dropout) @@ -102,7 +109,12 @@ class Decoder(nn.Module): ) -> nn.Sequential: self.res_block.append( - nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,) + nn.Conv2d( + self.embedding_dim, + channels[0], + kernel_size=1, + stride=1, + ) ) # Bottleneck module. diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py index d3adac5..b0cceed 100644 --- a/text_recognizer/networks/vqvae/encoder.py +++ b/text_recognizer/networks/vqvae/encoder.py @@ -1,5 +1,5 @@ """CNN encoder for the VQ-VAE.""" -from typing import List, Optional, Tuple, Type +from typing import Sequence, Optional, Tuple, Type import torch from torch import nn @@ -11,7 +11,10 @@ from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer class _ResidualBlock(nn.Module): def __init__( - self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]], + self, + in_channels: int, + out_channels: int, + dropout: Optional[Type[nn.Module]], ) -> None: super().__init__() self.block = [ @@ -36,9 +39,9 @@ class Encoder(nn.Module): def __init__( self, in_channels: int, - channels: List[int], - kernel_sizes: List[int], - strides: List[int], + channels: Sequence[int], + kernel_sizes: Sequence[int], + strides: Sequence[int], num_residual_layers: int, embedding_dim: int, num_embeddings: int, @@ -77,12 +80,12 @@ class Encoder(nn.Module): self.num_embeddings, self.embedding_dim, self.beta ) + @staticmethod def _build_compression_block( - self, in_channels: int, channels: int, - kernel_sizes: List[int], - strides: List[int], + kernel_sizes: Sequence[int], + strides: Sequence[int], activation: Type[nn.Module], dropout: Optional[Type[nn.Module]], ) -> nn.ModuleList: @@ -109,8 +112,8 @@ class Encoder(nn.Module): self, in_channels: int, channels: int, - kernel_sizes: List[int], - strides: List[int], + kernel_sizes: Sequence[int], + strides: Sequence[int], num_residual_layers: int, activation: Type[nn.Module], dropout: Optional[Type[nn.Module]], @@ -135,7 +138,12 @@ class Encoder(nn.Module): ) encoder.append( - nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,) + nn.Conv2d( + channels[-1], + self.embedding_dim, + kernel_size=1, + stride=1, + ) ) return nn.Sequential(*encoder) diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py index 50448b4..1f08e5e 100644 --- a/text_recognizer/networks/vqvae/vqvae.py +++ b/text_recognizer/networks/vqvae/vqvae.py @@ -1,8 +1,7 @@ """The VQ-VAE.""" -from typing import List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple -import torch from torch import nn from torch import Tensor @@ -25,6 +24,8 @@ class VQVAE(nn.Module): beta: float = 0.25, activation: str = "leaky_relu", dropout_rate: float = 0.0, + *args: Any, + **kwargs: Dict, ) -> None: super().__init__() |