| author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-05-02 13:51:15 +0200 |
|---|---|---|
| committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-05-02 13:51:15 +0200 |
| commit | 1d0977585f01c42e9f6280559a1a98037907a62e (patch) | |
| tree | 7e86dd71b163f3138ed2658cb52c44e805f21539 /text_recognizer/networks/cnn_transformer.py | |
| parent | 58ae7154aa945cfe5a46592cc1dfb28f0a4e51b3 (diff) | |
Implemented training script with hydra
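The training script itself is not part of this diff, which only touches `cnn_transformer.py`. For orientation, a Hydra-based training entry point is typically wired up as in the minimal sketch below; the config path, config name, and body are illustrative assumptions, not code from this commit.

```python
# Minimal sketch of a Hydra entry point (hypothetical names; assumes a
# configs/config.yaml exists). Hydra composes the YAML config tree and any
# command-line overrides into a single DictConfig.
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(config_path="configs", config_name="config")
def main(cfg: DictConfig) -> None:
    # The resolved config would be handed to the data module, network, and
    # trainer; here we just print it.
    print(OmegaConf.to_yaml(cfg))


if __name__ == "__main__":
    main()
```

A config composed this way is also the natural source of the `encoder: Union[DictConfig, Dict]` argument that `CNNTransformer._configure_encoder` in the diff below instantiates via `importlib`.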
Diffstat (limited to 'text_recognizer/networks/cnn_transformer.py')
| -rw-r--r-- | text_recognizer/networks/cnn_transformer.py | 364 |
1 file changed, 182 insertions(+), 182 deletions(-)
diff --git a/text_recognizer/networks/cnn_transformer.py b/text_recognizer/networks/cnn_transformer.py
index d42c29d..80798e1 100644
--- a/text_recognizer/networks/cnn_transformer.py
+++ b/text_recognizer/networks/cnn_transformer.py
@@ -1,182 +1,182 @@
-"""A Transformer with a CNN backbone.
-
-The network encodes an image with a convolutional backbone to a latent representation,
-i.e. feature maps. A 2d positional encoding is applied to the feature maps for
-spatial information. The resulting features are then sent to a transformer decoder
-together with the target tokens.
-
-TODO: Local attention for lower layer in attention.
-
-"""
-import importlib
-import math
-from typing import Dict, Optional, Union, Sequence, Type
-
-from einops import rearrange
-from omegaconf import DictConfig, OmegaConf
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.data.emnist import NUM_SPECIAL_TOKENS
-from text_recognizer.networks.transformer import (
-    Decoder,
-    DecoderLayer,
-    PositionalEncoding,
-    PositionalEncoding2D,
-    target_padding_mask,
-)
-
-NUM_WORD_PIECES = 1000
-
-
-class CNNTransformer(nn.Module):
-    def __init__(
-        self,
-        input_dim: Sequence[int],
-        output_dims: Sequence[int],
-        encoder: Union[DictConfig, Dict],
-        vocab_size: Optional[int] = None,
-        num_decoder_layers: int = 4,
-        hidden_dim: int = 256,
-        num_heads: int = 4,
-        expansion_dim: int = 1024,
-        dropout_rate: float = 0.1,
-        transformer_activation: str = "glu",
-        *args,
-        **kwargs,
-    ) -> None:
-        super().__init__()
-        self.vocab_size = (
-            NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size
-        )
-        self.pad_index = 3  # TODO: fix me
-        self.hidden_dim = hidden_dim
-        self.max_output_length = output_dims[0]
-
-        # Image backbone
-        self.encoder = self._configure_encoder(encoder)
-        self.encoder_proj = nn.Conv2d(256, hidden_dim, kernel_size=1)
-        self.feature_map_encoding = PositionalEncoding2D(
-            hidden_dim=hidden_dim, max_h=input_dim[1], max_w=input_dim[2]
-        )
-
-        # Target token embedding
-        self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim)
-        self.trg_position_encoding = PositionalEncoding(
-            hidden_dim, dropout_rate, max_len=output_dims[0]
-        )
-
-        # Transformer decoder
-        self.decoder = Decoder(
-            decoder_layer=DecoderLayer(
-                hidden_dim=hidden_dim,
-                num_heads=num_heads,
-                expansion_dim=expansion_dim,
-                dropout_rate=dropout_rate,
-                activation=transformer_activation,
-            ),
-            num_layers=num_decoder_layers,
-            norm=nn.LayerNorm(hidden_dim),
-        )
-
-        # Classification head
-        self.head = nn.Linear(hidden_dim, self.vocab_size)
-
-        # Initialize weights
-        self._init_weights()
-
-    def _init_weights(self) -> None:
-        """Initialize network weights."""
-        self.trg_embedding.weight.data.uniform_(-0.1, 0.1)
-        self.head.bias.data.zero_()
-        self.head.weight.data.uniform_(-0.1, 0.1)
-
-        nn.init.kaiming_normal_(
-            self.encoder_proj.weight.data,
-            a=0,
-            mode="fan_out",
-            nonlinearity="relu",
-        )
-        if self.encoder_proj.bias is not None:
-            _, fan_out = nn.init._calculate_fan_in_and_fan_out(
-                self.encoder_proj.weight.data
-            )
-            bound = 1 / math.sqrt(fan_out)
-            nn.init.normal_(self.encoder_proj.bias, -bound, bound)
-
-    @staticmethod
-    def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]:
-        encoder = OmegaConf.create(encoder)
-        args = encoder.args or {}
-        network_module = importlib.import_module("text_recognizer.networks")
-        encoder_class = getattr(network_module, encoder.type)
-        return encoder_class(**args)
-
-    def encode(self, image: Tensor) -> Tensor:
-        """Extracts image features with backbone.
-
-        Args:
-            image (Tensor): Image(s) of handwritten text.
-
-        Returns:
-            Tensor: Image features.
-
-        Shapes:
-            - image: :math:`(B, C, H, W)`
-            - latent: :math:`(B, T, C)`
-
-        """
-        # Extract image features.
-        image_features = self.encoder(image)
-        image_features = self.encoder_proj(image_features)
-
-        # Add 2d encoding to the feature maps.
-        image_features = self.feature_map_encoding(image_features)
-
-        # Collapse feature map height and width.
-        image_features = rearrange(image_features, "b c h w -> b (h w) c")
-        return image_features
-
-    def decode(self, memory: Tensor, trg: Tensor) -> Tensor:
-        """Decodes image features with transformer decoder."""
-        trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index)
-        trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim)
-        trg = rearrange(trg, "b t d -> t b d")
-        trg = self.trg_position_encoding(trg)
-        trg = rearrange(trg, "t b d -> b t d")
-        out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None)
-        logits = self.head(out)
-        return logits
-
-    def forward(self, image: Tensor, trg: Tensor) -> Tensor:
-        image_features = self.encode(image)
-        output = self.decode(image_features, trg)
-        output = rearrange(output, "b t c -> b c t")
-        return output
-
-    def predict(self, image: Tensor) -> Tensor:
-        """Transcribes text in image(s)."""
-        bsz = image.shape[0]
-        image_features = self.encode(image)
-
-        output_tokens = (
-            (torch.ones((bsz, self.max_output_length)) * self.pad_index)
-            .type_as(image)
-            .long()
-        )
-        output_tokens[:, 0] = self.start_index
-        for i in range(1, self.max_output_length):
-            trg = output_tokens[:, :i]
-            output = self.decode(image_features, trg)
-            output = torch.argmax(output, dim=-1)
-            output_tokens[:, i] = output[:, -1]
-
-        # Set all tokens after end token to be padding.
-        for i in range(1, self.max_output_length):
-            indices = (output_tokens[:, i - 1] == self.end_index) | (
-                output_tokens[:, i - 1] == self.pad_index
-            )
-            output_tokens[indices, i] = self.pad_index
-
-        return output_tokens
+# """A Transformer with a CNN backbone.
+#
+# The network encodes an image with a convolutional backbone to a latent representation,
+# i.e. feature maps. A 2d positional encoding is applied to the feature maps for
+# spatial information. The resulting features are then sent to a transformer decoder
+# together with the target tokens.
+#
+# TODO: Local attention for lower layer in attention.
+#
+# """
+# import importlib
+# import math
+# from typing import Dict, Optional, Union, Sequence, Type
+#
+# from einops import rearrange
+# from omegaconf import DictConfig, OmegaConf
+# import torch
+# from torch import nn
+# from torch import Tensor
+#
+# from text_recognizer.data.emnist import NUM_SPECIAL_TOKENS
+# from text_recognizer.networks.transformer import (
+#     Decoder,
+#     DecoderLayer,
+#     PositionalEncoding,
+#     PositionalEncoding2D,
+#     target_padding_mask,
+# )
+#
+# NUM_WORD_PIECES = 1000
+#
+#
+# class CNNTransformer(nn.Module):
+#     def __init__(
+#         self,
+#         input_dim: Sequence[int],
+#         output_dims: Sequence[int],
+#         encoder: Union[DictConfig, Dict],
+#         vocab_size: Optional[int] = None,
+#         num_decoder_layers: int = 4,
+#         hidden_dim: int = 256,
+#         num_heads: int = 4,
+#         expansion_dim: int = 1024,
+#         dropout_rate: float = 0.1,
+#         transformer_activation: str = "glu",
+#         *args,
+#         **kwargs,
+#     ) -> None:
+#         super().__init__()
+#         self.vocab_size = (
+#             NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size
+#         )
+#         self.pad_index = 3  # TODO: fix me
+#         self.hidden_dim = hidden_dim
+#         self.max_output_length = output_dims[0]
+#
+#         # Image backbone
+#         self.encoder = self._configure_encoder(encoder)
+#         self.encoder_proj = nn.Conv2d(256, hidden_dim, kernel_size=1)
+#         self.feature_map_encoding = PositionalEncoding2D(
+#             hidden_dim=hidden_dim, max_h=input_dim[1], max_w=input_dim[2]
+#         )
+#
+#         # Target token embedding
+#         self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim)
+#         self.trg_position_encoding = PositionalEncoding(
+#             hidden_dim, dropout_rate, max_len=output_dims[0]
+#         )
+#
+#         # Transformer decoder
+#         self.decoder = Decoder(
+#             decoder_layer=DecoderLayer(
+#                 hidden_dim=hidden_dim,
+#                 num_heads=num_heads,
+#                 expansion_dim=expansion_dim,
+#                 dropout_rate=dropout_rate,
+#                 activation=transformer_activation,
+#             ),
+#             num_layers=num_decoder_layers,
+#             norm=nn.LayerNorm(hidden_dim),
+#         )
+#
+#         # Classification head
+#         self.head = nn.Linear(hidden_dim, self.vocab_size)
+#
+#         # Initialize weights
+#         self._init_weights()
+#
+#     def _init_weights(self) -> None:
+#         """Initialize network weights."""
+#         self.trg_embedding.weight.data.uniform_(-0.1, 0.1)
+#         self.head.bias.data.zero_()
+#         self.head.weight.data.uniform_(-0.1, 0.1)
+#
+#         nn.init.kaiming_normal_(
+#             self.encoder_proj.weight.data,
+#             a=0,
+#             mode="fan_out",
+#             nonlinearity="relu",
+#         )
+#         if self.encoder_proj.bias is not None:
+#             _, fan_out = nn.init._calculate_fan_in_and_fan_out(
+#                 self.encoder_proj.weight.data
+#             )
+#             bound = 1 / math.sqrt(fan_out)
+#             nn.init.normal_(self.encoder_proj.bias, -bound, bound)
+#
+#     @staticmethod
+#     def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]:
+#         encoder = OmegaConf.create(encoder)
+#         args = encoder.args or {}
+#         network_module = importlib.import_module("text_recognizer.networks")
+#         encoder_class = getattr(network_module, encoder.type)
+#         return encoder_class(**args)
+#
+#     def encode(self, image: Tensor) -> Tensor:
+#         """Extracts image features with backbone.
+#
+#         Args:
+#             image (Tensor): Image(s) of handwritten text.
+#
+#         Returns:
+#             Tensor: Image features.
+#
+#         Shapes:
+#             - image: :math:`(B, C, H, W)`
+#             - latent: :math:`(B, T, C)`
+#
+#         """
+#         # Extract image features.
+#         image_features = self.encoder(image)
+#         image_features = self.encoder_proj(image_features)
+#
+#         # Add 2d encoding to the feature maps.
+#         image_features = self.feature_map_encoding(image_features)
+#
+#         # Collapse feature map height and width.
+#         image_features = rearrange(image_features, "b c h w -> b (h w) c")
+#         return image_features
+#
+#     def decode(self, memory: Tensor, trg: Tensor) -> Tensor:
+#         """Decodes image features with transformer decoder."""
+#         trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index)
+#         trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim)
+#         trg = rearrange(trg, "b t d -> t b d")
+#         trg = self.trg_position_encoding(trg)
+#         trg = rearrange(trg, "t b d -> b t d")
+#         out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None)
+#         logits = self.head(out)
+#         return logits
+#
+#     def forward(self, image: Tensor, trg: Tensor) -> Tensor:
+#         image_features = self.encode(image)
+#         output = self.decode(image_features, trg)
+#         output = rearrange(output, "b t c -> b c t")
+#         return output
+#
+#     def predict(self, image: Tensor) -> Tensor:
+#         """Transcribes text in image(s)."""
+#         bsz = image.shape[0]
+#         image_features = self.encode(image)
+#
+#         output_tokens = (
+#             (torch.ones((bsz, self.max_output_length)) * self.pad_index)
+#             .type_as(image)
+#             .long()
+#         )
+#         output_tokens[:, 0] = self.start_index
+#         for i in range(1, self.max_output_length):
+#             trg = output_tokens[:, :i]
+#             output = self.decode(image_features, trg)
+#             output = torch.argmax(output, dim=-1)
+#             output_tokens[:, i] = output[:, -1]
+#
+#         # Set all tokens after end token to be padding.
+#         for i in range(1, self.max_output_length):
+#             indices = (output_tokens[:, i - 1] == self.end_index) | (
+#                 output_tokens[:, i - 1] == self.pad_index
+#             )
+#             output_tokens[indices, i] = self.pad_index
+#
+#         return output_tokens
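The `predict` method above implements greedy autoregressive decoding: seed the output with a start token, repeatedly run the decoder on the prefix, append the most likely next token, and finally mask everything after the first end token. A self-contained sketch of that technique, with `decode_fn` standing in for `CNNTransformer.decode` and the start/end/pad indices as illustrative assumptions:

```python
# Greedy decoding sketch (assumed interface: decode_fn(memory, tokens) returns
# logits of shape (B, T, vocab_size), as CNNTransformer.decode does above).
from typing import Callable

import torch
from torch import Tensor


def greedy_decode(
    decode_fn: Callable[[Tensor, Tensor], Tensor],
    memory: Tensor,
    max_len: int,
    start_index: int,
    end_index: int,
    pad_index: int,
) -> Tensor:
    bsz = memory.shape[0]
    tokens = torch.full(
        (bsz, max_len), pad_index, dtype=torch.long, device=memory.device
    )
    tokens[:, 0] = start_index
    for i in range(1, max_len):
        logits = decode_fn(memory, tokens[:, :i])  # (B, i, vocab_size)
        tokens[:, i] = logits[:, -1].argmax(dim=-1)  # most likely next token
    # Everything after an end (or pad) token becomes padding.
    for i in range(1, max_len):
        done = (tokens[:, i - 1] == end_index) | (tokens[:, i - 1] == pad_index)
        tokens[done, i] = pad_index
    return tokens
```

Re-running the decoder on a prefix that grows one token at a time keeps the sketch simple but costs one decoder call per output position; key/value caching is the usual optimization and is omitted here.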