author     Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-04-22 08:15:58 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-04-22 08:15:58 +0200
commit     1ca8b0b9e0613c1e02f6a5d8b49e20c4d6916412 (patch)
tree       5e610ac459c9b254f8826e92372346f01f8e2412 /text_recognizer/networks
parent     ffa4be4bf4e3758e01d52a9c1f354a05a90b93de (diff)
Fixed training script, able to train vqvae
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--  text_recognizer/networks/__init__.py              |   2
-rw-r--r--  text_recognizer/networks/cnn_transformer.py       | 257
-rw-r--r--  text_recognizer/networks/image_transformer.py     | 165
-rw-r--r--  text_recognizer/networks/residual_network.py      |   6
-rw-r--r--  text_recognizer/networks/transducer/transducer.py |   7
-rw-r--r--  text_recognizer/networks/vqvae/decoder.py         |  20
-rw-r--r--  text_recognizer/networks/vqvae/encoder.py         |  30
-rw-r--r--  text_recognizer/networks/vqvae/vqvae.py           |   5
8 files changed, 182 insertions(+), 310 deletions(-)
diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py
index 979149f..41fd43f 100644
--- a/text_recognizer/networks/__init__.py
+++ b/text_recognizer/networks/__init__.py
@@ -1,2 +1,2 @@
"""Network modules"""
-from .image_transformer import ImageTransformer
+from .vqvae import VQVAE
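After this change the package exports VQVAE instead of ImageTransformer (the module itself is deleted below), so downstream imports shift accordingly; a quick sanity check:

import text_recognizer.networks as networks

assert hasattr(networks, "VQVAE")
assert not hasattr(networks, "ImageTransformer")  # removed by this commit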
diff --git a/text_recognizer/networks/cnn_transformer.py b/text_recognizer/networks/cnn_transformer.py
index 9150b55..e23a15d 100644
--- a/text_recognizer/networks/cnn_transformer.py
+++ b/text_recognizer/networks/cnn_transformer.py
@@ -1,158 +1,165 @@
-"""A CNN-Transformer for image to text recognition."""
-from typing import Dict, Optional, Tuple
+"""A Transformer with a cnn backbone.
+
+The network encodes a image with a convolutional backbone to a latent representation,
+i.e. feature maps. A 2d positional encoding is applied to the feature maps for
+spatial information. The resulting feature are then set to a transformer decoder
+together with the target tokens.
+
+TODO: Local attention for lower layer in attention.
+
+"""
+import importlib
+import math
+from typing import Dict, Optional, Union, Sequence, Type
from einops import rearrange
+from omegaconf import DictConfig, OmegaConf
import torch
from torch import nn
from torch import Tensor
-from text_recognizer.networks.transformer import PositionalEncoding, Transformer
-from text_recognizer.networks.util import activation_function
-from text_recognizer.networks.util import configure_backbone
+from text_recognizer.data.emnist import NUM_SPECIAL_TOKENS
+from text_recognizer.networks.transformer import (
+ Decoder,
+ DecoderLayer,
+ PositionalEncoding,
+ PositionalEncoding2D,
+ target_padding_mask,
+)
+NUM_WORD_PIECES = 1000
-class CNNTransformer(nn.Module):
- """CNN+Transfomer for image to sequence prediction."""
+class CNNTransformer(nn.Module):
def __init__(
self,
- num_encoder_layers: int,
- num_decoder_layers: int,
- hidden_dim: int,
- vocab_size: int,
- num_heads: int,
- adaptive_pool_dim: Tuple,
- expansion_dim: int,
- dropout_rate: float,
- trg_pad_index: int,
- max_len: int,
- backbone: str,
- backbone_args: Optional[Dict] = None,
- activation: str = "gelu",
- pool_kernel: Optional[Tuple[int, int]] = None,
+ input_shape: Sequence[int],
+ output_shape: Sequence[int],
+ encoder: Union[DictConfig, Dict],
+ vocab_size: Optional[int] = None,
+ num_decoder_layers: int = 4,
+ hidden_dim: int = 256,
+ num_heads: int = 4,
+ expansion_dim: int = 1024,
+ dropout_rate: float = 0.1,
+ transformer_activation: str = "glu",
) -> None:
- super().__init__()
+ super().__init__()
- self.trg_pad_index = trg_pad_index
- self.vocab_size = vocab_size
- self.backbone = configure_backbone(backbone, backbone_args)
-
- if pool_kernel is not None:
- self.max_pool = nn.MaxPool2d(pool_kernel, stride=2)
- else:
- self.max_pool = None
-
- self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim)
-
- self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
- self.pos_dropout = nn.Dropout(p=dropout_rate)
- self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
-
- nn.init.normal_(self.character_embedding.weight, std=0.02)
-
- self.adaptive_pool = (
- nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None
+ self.vocab_size = (
+ NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size
)
+ self.hidden_dim = hidden_dim
+ self.max_output_length = output_shape[0]
- self.transformer = Transformer(
- num_encoder_layers,
- num_decoder_layers,
- hidden_dim,
- num_heads,
- expansion_dim,
- dropout_rate,
- activation,
+ # Image backbone
+ self.encoder = self._configure_encoder(encoder)
+ self.feature_map_encoding = PositionalEncoding2D(
+ hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2]
)
- self.head = nn.Sequential(
- # nn.Linear(hidden_dim, hidden_dim * 2),
- # activation_function(activation),
- nn.Linear(hidden_dim, vocab_size),
- )
+ # Target token embedding
+ self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim)
+ self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
- def _create_trg_mask(self, trg: Tensor) -> Tensor:
- # Move this outside the transformer.
- trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
- trg_len = trg.shape[1]
- trg_sub_mask = torch.tril(
- torch.ones((trg_len, trg_len), device=trg.device)
- ).bool()
- trg_mask = trg_pad_mask & trg_sub_mask
- return trg_mask
-
- def encoder(self, src: Tensor) -> Tensor:
- """Forward pass with the encoder of the transformer."""
- return self.transformer.encoder(src)
-
- def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
- """Forward pass with the decoder of the transformer + classification head."""
- return self.head(
- self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+ # Transformer decoder
+ self.decoder = Decoder(
+ decoder_layer=DecoderLayer(
+ hidden_dim=hidden_dim,
+ num_heads=num_heads,
+ expansion_dim=expansion_dim,
+ dropout_rate=dropout_rate,
+ activation=transformer_activation,
+ ),
+ num_layers=num_decoder_layers,
+ norm=nn.LayerNorm(hidden_dim),
)
- def extract_image_features(self, src: Tensor) -> Tensor:
- """Extracts image features with a backbone neural network.
-
- It seem like the winning idea was to swap channels and width dimension and collapse
- the height dimension. The transformer is learning like a baby with this implementation!!! :D
- Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D
+ # Classification head
+ self.head = nn.Linear(hidden_dim, self.vocab_size)
- Args:
- src (Tensor): Input tensor.
+ # Initialize weights
+ self._init_weights()
- Returns:
- Tensor: A input src to the transformer.
+ def _init_weights(self) -> None:
+ """Initialize network weights."""
+ self.trg_embedding.weight.data.uniform_(-0.1, 0.1)
+ self.head.bias.data.zero_()
+ self.head.weight.data.uniform_(-0.1, 0.1)
- """
- # If batch dimension is missing, it needs to be added.
- if len(src.shape) < 4:
- src = src[(None,) * (4 - len(src.shape))]
-
- src = self.backbone(src)
-
- if self.max_pool is not None:
- src = self.max_pool(src)
-
- if self.adaptive_pool is not None and len(src.shape) == 4:
- src = rearrange(src, "b c h w -> b w c h")
- src = self.adaptive_pool(src)
- src = src.squeeze(3)
- elif len(src.shape) == 4:
- src = rearrange(src, "b c h w -> b (h w) c")
+ nn.init.kaiming_normal_(
+ self.feature_map_encoding.weight.data,
+ a=0,
+ mode="fan_out",
+ nonlinearity="relu",
+ )
+ if self.feature_map_encoding.bias is not None:
+ _, fan_out = nn.init._calculate_fan_in_and_fan_out(
+ self.feature_map_encoding.weight.data
+ )
+ bound = 1 / math.sqrt(fan_out)
+ nn.init.normal_(self.feature_map_encoding.bias, -bound, bound)
+
+ @staticmethod
+ def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]:
+ encoder = OmegaConf.create(encoder)
+ network_module = importlib.import_module("text_recognizer.networks")
+ encoder_class = getattr(network_module, encoder.type)
+ return encoder_class(**encoder.args)
+
+ def encode(self, image: Tensor) -> Tensor:
+ """Extracts image features with backbone.
- b, t, _ = src.shape
+ Args:
+ image (Tensor): Image(s) of handwritten text.
- src += self.src_position_embedding[:, :t]
- src = self.pos_dropout(src)
+ Returns:
+ Tensor: Image features.
- return src
+ Shapes:
+ - image: :math:`(B, C, H, W)`
+ - features: :math:`(B, T, C)`
- def target_embedding(self, trg: Tensor) -> Tuple[Tensor, Tensor]:
- """Encodes target tensor with embedding and postion.
+ """
+ # Extract image features.
+ image_features = self.encoder(image)
- Args:
- trg (Tensor): Target tensor.
+ # Add 2d encoding to the feature maps.
+ image_features = self.feature_map_encoding(image_features)
- Returns:
- Tuple[Tensor, Tensor]: Encoded target tensor and target mask.
+ # Collapse the feature map height and width into one sequence dimension.
+ image_features = rearrange(image_features, "b c h w -> b (h w) c")
+ return image_features
- """
- trg = self.character_embedding(trg.long())
+ def decode(self, memory: Tensor, trg: Tensor) -> Tensor:
+ """Decodes image features with transformer decoder."""
+ trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index)
+ trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim)
trg = self.trg_position_encoding(trg)
- return trg
-
- def decode_image_features(
- self, image_features: Tensor, trg: Optional[Tensor] = None
- ) -> Tensor:
- """Takes images features from the backbone and decodes them with the transformer."""
- trg_mask = self._create_trg_mask(trg)
- trg = self.target_embedding(trg)
- out = self.transformer(image_features, trg, trg_mask=trg_mask)
-
+ out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None)
logits = self.head(out)
return logits
- def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
- """Forward pass with CNN transfomer."""
- image_features = self.extract_image_features(x)
- logits = self.decode_image_features(image_features, trg)
- return logits
+ def predict(self, image: Tensor) -> Tensor:
+ """Transcribes text in image(s)."""
+ bsz = image.shape[0]
+ image_features = self.encode(image)
+
+ output_tokens = (
+ (torch.ones((bsz, self.max_output_length)) * self.pad_index)
+ .type_as(image)
+ .long()
+ )
+ output_tokens[:, 0] = self.start_index
+ for i in range(1, self.max_output_length):
+ trg = output_tokens[:, :i]
+ output = self.decode(image_features, trg)
+ output = torch.argmax(output, dim=-1)
+ output_tokens[:, i] = output[:, -1]  # last predicted token per batch element
+
+ # Set all tokens after end token to be padding.
+ for i in range(1, self.max_output_length):
+ indices = (output_tokens[:, i - 1] == self.end_index) | (
+ output_tokens[:, i - 1] == self.pad_index
+ )
+ output_tokens[indices, i] = self.pad_index
+
+ return output_tokens
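A minimal usage sketch of the new CNNTransformer. The encoder spec mirrors what _configure_encoder expects: a dict (or DictConfig) whose `type` names a class exported from text_recognizer.networks and whose `args` feed its constructor. The class name, shapes, and token indices below are hypothetical; note also that decode() and predict() read self.pad_index, self.start_index, and self.end_index, which this commit never assigns, so they must be supplied before inference.

import torch

from text_recognizer.networks.cnn_transformer import CNNTransformer

encoder_config = {
    "type": "Encoder",           # hypothetical class in text_recognizer.networks
    "args": {"in_channels": 1},  # hypothetical constructor arguments
}

model = CNNTransformer(
    input_shape=(1, 576, 640),  # (C, H, W); H and W bound the 2D positional encoding
    output_shape=(97,),         # output_shape[0] becomes max_output_length
    encoder=encoder_config,
)

# Hypothetical special-token indices, required by decode()/predict().
model.pad_index, model.start_index, model.end_index = 3, 1, 2

tokens = model.predict(torch.randn(2, 1, 576, 640))  # (2, 97) token indices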
diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py
deleted file mode 100644
index a6aaca4..0000000
--- a/text_recognizer/networks/image_transformer.py
+++ /dev/null
@@ -1,165 +0,0 @@
-"""A Transformer with a cnn backbone.
-
-The network encodes a image with a convolutional backbone to a latent representation,
-i.e. feature maps. A 2d positional encoding is applied to the feature maps for
-spatial information. The resulting feature are then set to a transformer decoder
-together with the target tokens.
-
-TODO: Local attention for lower layer in attention.
-
-"""
-import importlib
-import math
-from typing import Dict, Optional, Union, Sequence, Type
-
-from einops import rearrange
-from omegaconf import DictConfig, OmegaConf
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.data.emnist import NUM_SPECIAL_TOKENS
-from text_recognizer.networks.transformer import (
- Decoder,
- DecoderLayer,
- PositionalEncoding,
- PositionalEncoding2D,
- target_padding_mask,
-)
-
-NUM_WORD_PIECES = 1000
-
-
-class ImageTransformer(nn.Module):
- def __init__(
- self,
- input_shape: Sequence[int],
- output_shape: Sequence[int],
- encoder: Union[DictConfig, Dict],
- vocab_size: Optional[int] = None,
- num_decoder_layers: int = 4,
- hidden_dim: int = 256,
- num_heads: int = 4,
- expansion_dim: int = 1024,
- dropout_rate: float = 0.1,
- transformer_activation: str = "glu",
- ) -> None:
- self.vocab_size = (
- NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size
- )
- self.hidden_dim = hidden_dim
- self.max_output_length = output_shape[0]
-
- # Image backbone
- self.encoder = self._configure_encoder(encoder)
- self.feature_map_encoding = PositionalEncoding2D(
- hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2]
- )
-
- # Target token embedding
- self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim)
- self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
-
- # Transformer decoder
- self.decoder = Decoder(
- decoder_layer=DecoderLayer(
- hidden_dim=hidden_dim,
- num_heads=num_heads,
- expansion_dim=expansion_dim,
- dropout_rate=dropout_rate,
- activation=transformer_activation,
- ),
- num_layers=num_decoder_layers,
- norm=nn.LayerNorm(hidden_dim),
- )
-
- # Classification head
- self.head = nn.Linear(hidden_dim, self.vocab_size)
-
- # Initialize weights
- self._init_weights()
-
- def _init_weights(self) -> None:
- """Initialize network weights."""
- self.trg_embedding.weight.data.uniform_(-0.1, 0.1)
- self.head.bias.data.zero_()
- self.head.weight.data.uniform_(-0.1, 0.1)
-
- nn.init.kaiming_normal_(
- self.feature_map_encoding.weight.data,
- a=0,
- mode="fan_out",
- nonlinearity="relu",
- )
- if self.feature_map_encoding.bias is not None:
- _, fan_out = nn.init._calculate_fan_in_and_fan_out(
- self.feature_map_encoding.weight.data
- )
- bound = 1 / math.sqrt(fan_out)
- nn.init.normal_(self.feature_map_encoding.bias, -bound, bound)
-
- @staticmethod
- def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]:
- encoder = OmegaConf.create(encoder)
- network_module = importlib.import_module("text_recognizer.networks")
- encoder_class = getattr(network_module, encoder.type)
- return encoder_class(**encoder.args)
-
- def encode(self, image: Tensor) -> Tensor:
- """Extracts image features with backbone.
-
- Args:
- image (Tensor): Image(s) of handwritten text.
-
- Retuns:
- Tensor: Image features.
-
- Shapes:
- - image: :math: `(B, C, H, W)`
- - latent: :math: `(B, T, C)`
-
- """
- # Extract image features.
- image_features = self.encoder(image)
-
- # Add 2d encoding to the feature maps.
- image_features = self.feature_map_encoding(image_features)
-
- # Collapse features maps height and width.
- image_features = rearrange(image_features, "b c h w -> b (h w) c")
- return image_features
-
- def decode(self, memory: Tensor, trg: Tensor) -> Tensor:
- """Decodes image features with transformer decoder."""
- trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index)
- trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim)
- trg = self.trg_position_encoding(trg)
- out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None)
- logits = self.head(out)
- return logits
-
- def predict(self, image: Tensor) -> Tensor:
- """Transcribes text in image(s)."""
- bsz = image.shape[0]
- image_features = self.encode(image)
-
- output_tokens = (
- (torch.ones((bsz, self.max_output_length)) * self.pad_index)
- .type_as(image)
- .long()
- )
- output_tokens[:, 0] = self.start_index
- for i in range(1, self.max_output_length):
- trg = output_tokens[:, :i]
- output = self.decode(image_features, trg)
- output = torch.argmax(output, dim=-1)
- output_tokens[:, i] = output[-1:]
-
- # Set all tokens after end token to be padding.
- for i in range(1, self.max_output_length):
- indices = output_tokens[:, i - 1] == self.end_index | (
- output_tokens[:, i - 1] == self.pad_index
- )
- output_tokens[indices, i] = self.pad_index
-
- return output_tokens
diff --git a/text_recognizer/networks/residual_network.py b/text_recognizer/networks/residual_network.py
index c33f419..da7553d 100644
--- a/text_recognizer/networks/residual_network.py
+++ b/text_recognizer/networks/residual_network.py
@@ -20,7 +20,11 @@ class Conv2dAuto(nn.Conv2d):
def conv_bn(in_channels: int, out_channels: int, *args, **kwargs) -> nn.Sequential:
"""3x3 convolution with batch norm."""
- conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,)
+ conv3x3 = partial(
+ Conv2dAuto,
+ kernel_size=3,
+ bias=False,
+ )
return nn.Sequential(
conv3x3(in_channels, out_channels, *args, **kwargs),
nn.BatchNorm2d(out_channels),
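For context, conv_bn delegates padding to Conv2dAuto, whose body sits outside this hunk. A minimal sketch of the auto-padding idea it presumably implements, deriving "same"-style padding from the kernel size:

import torch
from torch import nn


class Conv2dAutoSketch(nn.Conv2d):
    """nn.Conv2d that pads so odd kernels preserve spatial size at stride 1."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Dynamic "same" padding derived from the kernel size.
        self.padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2)


out = Conv2dAutoSketch(8, 16, kernel_size=3, bias=False)(torch.randn(1, 8, 32, 32))
assert out.shape == (1, 16, 32, 32)  # height and width preserved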
diff --git a/text_recognizer/networks/transducer/transducer.py b/text_recognizer/networks/transducer/transducer.py
index d7e3d08..b10f93a 100644
--- a/text_recognizer/networks/transducer/transducer.py
+++ b/text_recognizer/networks/transducer/transducer.py
@@ -392,7 +392,12 @@ def load_transducer_loss(
transitions = gtn.load(str(processed_path / transitions))
preprocessor = Preprocessor(
- data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep,
+ data_dir,
+ num_features,
+ tokens_path,
+ lexicon_path,
+ use_words,
+ prepend_wordsep,
)
num_tokens = preprocessor.num_tokens
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
index 8847aba..93a1e43 100644
--- a/text_recognizer/networks/vqvae/decoder.py
+++ b/text_recognizer/networks/vqvae/decoder.py
@@ -44,7 +44,12 @@ class Decoder(nn.Module):
# Configure decoder.
self.decoder = self._build_decoder(
- channels, kernel_sizes, strides, num_residual_layers, activation, dropout,
+ channels,
+ kernel_sizes,
+ strides,
+ num_residual_layers,
+ activation,
+ dropout,
)
def _build_decompression_block(
@@ -72,8 +77,10 @@ class Decoder(nn.Module):
)
)
- if i < len(self.upsampling):
- modules.append(nn.Upsample(size=self.upsampling[i]),)
+ if self.upsampling and i < len(self.upsampling):
+ modules.append(
+ nn.Upsample(size=self.upsampling[i]),
+ )
if dropout is not None:
modules.append(dropout)
@@ -102,7 +109,12 @@ class Decoder(nn.Module):
) -> nn.Sequential:
self.res_block.append(
- nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,)
+ nn.Conv2d(
+ self.embedding_dim,
+ channels[0],
+ kernel_size=1,
+ stride=1,
+ )
)
# Bottleneck module.
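The two hunks above guard the nn.Upsample insertion against a missing `upsampling` list and reflow the 1x1 convolution that lifts quantized latents back into conv channel space. A standalone sketch of that entry convolution, with hypothetical sizes:

import torch
from torch import nn

embedding_dim, first_channels = 64, 128  # hypothetical sizes
to_channels = nn.Conv2d(embedding_dim, first_channels, kernel_size=1, stride=1)

z_q = torch.randn(2, embedding_dim, 18, 20)  # quantized latents (B, D, H', W')
print(to_channels(z_q).shape)                # torch.Size([2, 128, 18, 20])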
diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py
index d3adac5..b0cceed 100644
--- a/text_recognizer/networks/vqvae/encoder.py
+++ b/text_recognizer/networks/vqvae/encoder.py
@@ -1,5 +1,5 @@
"""CNN encoder for the VQ-VAE."""
-from typing import List, Optional, Tuple, Type
+from typing import Sequence, Optional, Tuple, Type
import torch
from torch import nn
@@ -11,7 +11,10 @@ from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer
class _ResidualBlock(nn.Module):
def __init__(
- self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]],
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: Optional[Type[nn.Module]],
) -> None:
super().__init__()
self.block = [
@@ -36,9 +39,9 @@ class Encoder(nn.Module):
def __init__(
self,
in_channels: int,
- channels: List[int],
- kernel_sizes: List[int],
- strides: List[int],
+ channels: Sequence[int],
+ kernel_sizes: Sequence[int],
+ strides: Sequence[int],
num_residual_layers: int,
embedding_dim: int,
num_embeddings: int,
@@ -77,12 +80,12 @@ class Encoder(nn.Module):
self.num_embeddings, self.embedding_dim, self.beta
)
+ @staticmethod
def _build_compression_block(
- self,
in_channels: int,
channels: int,
- kernel_sizes: List[int],
- strides: List[int],
+ kernel_sizes: Sequence[int],
+ strides: Sequence[int],
activation: Type[nn.Module],
dropout: Optional[Type[nn.Module]],
) -> nn.ModuleList:
@@ -109,8 +112,8 @@ class Encoder(nn.Module):
self,
in_channels: int,
channels: int,
- kernel_sizes: List[int],
- strides: List[int],
+ kernel_sizes: Sequence[int],
+ strides: Sequence[int],
num_residual_layers: int,
activation: Type[nn.Module],
dropout: Optional[Type[nn.Module]],
@@ -135,7 +138,12 @@ class Encoder(nn.Module):
)
encoder.append(
- nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,)
+ nn.Conv2d(
+ channels[-1],
+ self.embedding_dim,
+ kernel_size=1,
+ stride=1,
+ )
)
return nn.Sequential(*encoder)
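The encoder ends with this 1x1 convolution so its output lives in the codebook's embedding space before the VectorQuantizer runs. A compact sketch of the quantization step in its standard VQ-VAE form (nearest codebook entry plus a straight-through gradient); this repo's VectorQuantizer may differ in detail:

import torch
from torch import nn
from einops import rearrange

num_embeddings, embedding_dim = 512, 64  # hypothetical codebook size
codebook = nn.Embedding(num_embeddings, embedding_dim)

z_e = torch.randn(2, embedding_dim, 18, 20)                  # encoder output
flat = rearrange(z_e, "b d h w -> (b h w) d")
indices = torch.cdist(flat, codebook.weight).argmin(dim=-1)  # nearest entries
z_q = rearrange(codebook(indices), "(b h w) d -> b d h w", b=2, h=18, w=20)
z_q = z_e + (z_q - z_e).detach()                             # straight-through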
diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py
index 50448b4..1f08e5e 100644
--- a/text_recognizer/networks/vqvae/vqvae.py
+++ b/text_recognizer/networks/vqvae/vqvae.py
@@ -1,8 +1,7 @@
"""The VQ-VAE."""
-from typing import List, Optional, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple
-import torch
from torch import nn
from torch import Tensor
@@ -25,6 +24,8 @@ class VQVAE(nn.Module):
beta: float = 0.25,
activation: str = "leaky_relu",
dropout_rate: float = 0.0,
+ *args: Any,
+ **kwargs: Any,
) -> None:
super().__init__()
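With *args/**kwargs absorbed, VQVAE can now be built from a config carrying extra keys, which is presumably what the fixed training script relies on. A hedged end-to-end sketch: the constructor arguments echo the encoder signature above, and the return convention (reconstruction plus vector-quantization loss) is an assumption, not shown in this diff:

import torch
import torch.nn.functional as F

from text_recognizer.networks import VQVAE

model = VQVAE(
    in_channels=1,
    channels=[32, 64],    # hypothetical
    kernel_sizes=[4, 4],  # hypothetical
    strides=[2, 2],       # hypothetical
    num_residual_layers=2,
    embedding_dim=64,
    num_embeddings=512,
    beta=0.25,
)

images = torch.randn(8, 1, 576, 640)
reconstruction, vq_loss = model(images)  # assumed return convention
loss = F.mse_loss(reconstruction, images) + vq_loss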