From beeaef529e7c893a3475fe27edc880e283373725 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Sun, 8 Nov 2020 12:41:04 +0100 Subject: Trying to get the CNNTransformer to work, but it is hard. --- src/text_recognizer/networks/__init__.py | 2 + src/text_recognizer/networks/cnn_transformer.py | 58 ++++++++++++----- .../networks/cnn_transformer_encoder.py | 73 ++++++++++++++++++++++ src/text_recognizer/networks/crnn.py | 40 ++++++++---- src/text_recognizer/networks/ctc.py | 2 +- src/text_recognizer/networks/densenet.py | 4 +- src/text_recognizer/networks/loss.py | 39 +++++++++++- .../networks/transformer/positional_encoding.py | 1 + .../networks/transformer/sparse_transformer.py | 1 - .../networks/transformer/transformer.py | 1 + src/text_recognizer/networks/util.py | 4 +- src/text_recognizer/networks/vision_transformer.py | 19 +++--- 12 files changed, 197 insertions(+), 47 deletions(-) create mode 100644 src/text_recognizer/networks/cnn_transformer_encoder.py delete mode 100644 src/text_recognizer/networks/transformer/sparse_transformer.py (limited to 'src/text_recognizer/networks') diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py index 8b87797..6d88768 100644 --- a/src/text_recognizer/networks/__init__.py +++ b/src/text_recognizer/networks/__init__.py @@ -1,5 +1,6 @@ """Network modules.""" from .cnn_transformer import CNNTransformer +from .cnn_transformer_encoder import CNNTransformerEncoder from .crnn import ConvolutionalRecurrentNetwork from .ctc import greedy_decoder from .densenet import DenseNet @@ -15,6 +16,7 @@ from .wide_resnet import WideResidualNetwork __all__ = [ "CNNTransformer", + "CNNTransformerEncoder", "ConvolutionalRecurrentNetwork", "DenseNet", "EmbeddingLoss", diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py index 8666f11..3da2c9f 100644 --- a/src/text_recognizer/networks/cnn_transformer.py +++ b/src/text_recognizer/networks/cnn_transformer.py @@ -1,8 +1,7 @@ """A DETR style transfomers but for text recognition.""" -from typing import Dict, Optional, Tuple, Type +from typing import Dict, Optional, Tuple -from einops.layers.torch import Rearrange -from loguru import logger +from einops import rearrange import torch from torch import nn from torch import Tensor @@ -21,23 +20,32 @@ class CNNTransformer(nn.Module): hidden_dim: int, vocab_size: int, num_heads: int, - max_len: int, + adaptive_pool_dim: Tuple, expansion_dim: int, dropout_rate: float, trg_pad_index: int, backbone: str, + out_channels: int, + max_len: int, backbone_args: Optional[Dict] = None, activation: str = "gelu", ) -> None: super().__init__() self.trg_pad_index = trg_pad_index - self.backbone_args = backbone_args + self.backbone = configure_backbone(backbone, backbone_args) self.character_embedding = nn.Embedding(vocab_size, hidden_dim) - self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate, max_len) - self.collapse_spatial_dim = nn.Sequential( - Rearrange("b t h w -> b t (h w)"), nn.AdaptiveAvgPool2d((None, hidden_dim)) + + # self.conv = nn.Conv2d(out_channels, max_len, kernel_size=1) + + self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate) + self.row_embed = nn.Parameter(torch.rand(max_len, max_len // 2)) + self.col_embed = nn.Parameter(torch.rand(max_len, max_len // 2)) + + self.adaptive_pool = ( + nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None ) + self.transformer = Transformer( num_encoder_layers, num_decoder_layers, @@ -47,7 +55,8 @@ class CNNTransformer(nn.Module): dropout_rate, activation, ) - self.head = nn.Linear(hidden_dim, vocab_size) + + self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),) def _create_trg_mask(self, trg: Tensor) -> Tensor: # Move this outside the transformer. @@ -83,8 +92,22 @@ class CNNTransformer(nn.Module): if len(src.shape) < 4: src = src[(None,) * (4 - len(src.shape))] src = self.backbone(src) - src = self.collapse_spatial_dim(src) - src = self.position_encoding(src) + # src = self.conv(src) + if self.adaptive_pool is not None: + src = self.adaptive_pool(src) + H, W = src.shape[-2:] + src = rearrange(src, "b t h w -> b t (h w)") + + # construct positional encodings + pos = torch.cat( + [ + self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1), + self.row_embed[:H].unsqueeze(1).repeat(1, W, 1), + ], + dim=-1, + ).unsqueeze(0) + pos = rearrange(pos, "b h w l -> b l (h w)") + src = pos + 0.1 * src return src def preprocess_target(self, trg: Tensor) -> Tuple[Tensor, Tensor]: @@ -97,15 +120,16 @@ class CNNTransformer(nn.Module): Tuple[Tensor, Tensor]: Encoded target tensor and target mask. """ - trg_mask = self._create_trg_mask(trg) trg = self.character_embedding(trg.long()) trg = self.position_encoding(trg) - return trg, trg_mask + return trg - def forward(self, x: Tensor, trg: Tensor) -> Tensor: + def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor: """Forward pass with CNN transfomer.""" - src = self.preprocess_input(x) - trg, trg_mask = self.preprocess_target(trg) - out = self.transformer(src, trg, trg_mask=trg_mask) + h = self.preprocess_input(x) + trg_mask = self._create_trg_mask(trg) + trg = self.preprocess_target(trg) + out = self.transformer(h, trg, trg_mask=trg_mask) + logits = self.head(out) return logits diff --git a/src/text_recognizer/networks/cnn_transformer_encoder.py b/src/text_recognizer/networks/cnn_transformer_encoder.py new file mode 100644 index 0000000..93626bf --- /dev/null +++ b/src/text_recognizer/networks/cnn_transformer_encoder.py @@ -0,0 +1,73 @@ +"""Network with a CNN backend and a transformer encoder head.""" +from typing import Dict + +from einops import rearrange +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.transformer import PositionalEncoding +from text_recognizer.networks.util import configure_backbone + + +class CNNTransformerEncoder(nn.Module): + """A CNN backbone with Transformer Encoder frontend for sequence prediction.""" + + def __init__( + self, + backbone: str, + backbone_args: Dict, + mlp_dim: int, + d_model: int, + nhead: int = 8, + dropout_rate: float = 0.1, + activation: str = "relu", + num_layers: int = 6, + num_classes: int = 80, + num_channels: int = 256, + max_len: int = 97, + ) -> None: + super().__init__() + self.d_model = d_model + self.nhead = nhead + self.dropout_rate = dropout_rate + self.activation = activation + self.num_layers = num_layers + + self.backbone = configure_backbone(backbone, backbone_args) + self.position_encoding = PositionalEncoding(d_model, dropout_rate) + self.encoder = self._configure_encoder() + + self.conv = nn.Conv2d(num_channels, max_len, kernel_size=1) + + self.mlp = nn.Linear(mlp_dim, d_model) + + self.head = nn.Linear(d_model, num_classes) + + def _configure_encoder(self) -> nn.TransformerEncoder: + encoder_layer = nn.TransformerEncoderLayer( + d_model=self.d_model, + nhead=self.nhead, + dropout=self.dropout_rate, + activation=self.activation, + ) + norm = nn.LayerNorm(self.d_model) + return nn.TransformerEncoder( + encoder_layer=encoder_layer, num_layers=self.num_layers, norm=norm + ) + + def forward(self, x: Tensor, targets: Tensor = None) -> Tensor: + """Forward pass through the network.""" + if len(x.shape) < 4: + x = x[(None,) * (4 - len(x.shape))] + + x = self.conv(self.backbone(x)) + x = rearrange(x, "b c h w -> b c (h w)") + x = self.mlp(x) + x = self.position_encoding(x) + x = rearrange(x, "b c h-> c b h") + x = self.encoder(x) + x = rearrange(x, "c b h-> b c h") + logits = self.head(x) + + return logits diff --git a/src/text_recognizer/networks/crnn.py b/src/text_recognizer/networks/crnn.py index 3e605e2..9747429 100644 --- a/src/text_recognizer/networks/crnn.py +++ b/src/text_recognizer/networks/crnn.py @@ -1,12 +1,9 @@ """LSTM with CTC for handwritten text recognition within a line.""" -import importlib -from pathlib import Path -from typing import Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Dict, Tuple from einops import rearrange, reduce -from einops.layers.torch import Rearrange, Reduce +from einops.layers.torch import Rearrange from loguru import logger -import torch from torch import nn from torch import Tensor @@ -28,16 +25,21 @@ class ConvolutionalRecurrentNetwork(nn.Module): patch_size: Tuple[int, int] = (28, 28), stride: Tuple[int, int] = (1, 14), recurrent_cell: str = "lstm", + avg_pool: bool = False, + use_sliding_window: bool = True, ) -> None: super().__init__() self.backbone_args = backbone_args or {} self.patch_size = patch_size self.stride = stride - self.sliding_window = self._configure_sliding_window() + self.sliding_window = ( + self._configure_sliding_window() if use_sliding_window else None + ) self.input_size = input_size self.hidden_size = hidden_size self.backbone = configure_backbone(backbone, backbone_args) self.bidirectional = bidirectional + self.avg_pool = avg_pool if recurrent_cell.upper() in ["LSTM", "GRU"]: recurrent_cell = getattr(nn, recurrent_cell) @@ -76,15 +78,27 @@ class ConvolutionalRecurrentNetwork(nn.Module): """Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM.""" if len(x.shape) < 4: x = x[(None,) * (4 - len(x.shape))] - x = self.sliding_window(x) - # Rearrange from a sequence of patches for feedforward network. - b, t = x.shape[:2] - x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t) - x = self.backbone(x) + if self.sliding_window is not None: + # Create image patches with a sliding window kernel. + x = self.sliding_window(x) + + # Rearrange from a sequence of patches for feedforward network. + b, t = x.shape[:2] + x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t) - # Avgerage pooling. - x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t) + x = self.backbone(x) + + # Avgerage pooling. + if self.avg_pool: + x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t) + else: + x = rearrange(x, "(b t) h -> t b h", b=b, t=t) + else: + # Encode the entire image with a CNN, and use the channels as temporal dimension. + b = x.shape[0] + x = self.backbone(x) + x = rearrange(x, "b c h w -> c b (h w)", b=b) # Sequence predictions. x, _ = self.rnn(x) diff --git a/src/text_recognizer/networks/ctc.py b/src/text_recognizer/networks/ctc.py index 2493d5c..af9b700 100644 --- a/src/text_recognizer/networks/ctc.py +++ b/src/text_recognizer/networks/ctc.py @@ -33,7 +33,7 @@ def greedy_decoder( """ if character_mapper is None: - character_mapper = EmnistMapper() + character_mapper = EmnistMapper(pad_token="_") # noqa: S106 predictions = rearrange(torch.argmax(predictions, dim=2), "t b -> b t") decoded_predictions = [] diff --git a/src/text_recognizer/networks/densenet.py b/src/text_recognizer/networks/densenet.py index d2aad60..7dc58d9 100644 --- a/src/text_recognizer/networks/densenet.py +++ b/src/text_recognizer/networks/densenet.py @@ -72,7 +72,7 @@ class _DenseBlock(nn.Module): ) -> None: super().__init__() self.dense_block = self._build_dense_blocks( - num_layers, in_channels, bn_size, growth_rate, dropout_rate, activation + num_layers, in_channels, bn_size, growth_rate, dropout_rate, activation, ) def _build_dense_blocks( @@ -219,7 +219,7 @@ class DenseNet(nn.Module): def forward(self, x: Tensor) -> Tensor: """Forward pass of Densenet.""" - # If batch dimenstion is missing, it needs to be added. + # If batch dimenstion is missing, it will be added. if len(x.shape) < 4: x = x[(None,) * (4 - len(x.shape))] return self.densenet(x) diff --git a/src/text_recognizer/networks/loss.py b/src/text_recognizer/networks/loss.py index ff843cf..cf9fa0d 100644 --- a/src/text_recognizer/networks/loss.py +++ b/src/text_recognizer/networks/loss.py @@ -1,10 +1,12 @@ """Implementations of custom loss functions.""" from pytorch_metric_learning import distances, losses, miners, reducers +import torch from torch import nn from torch import Tensor +from torch.autograd import Variable +import torch.nn.functional as F - -__all__ = ["EmbeddingLoss"] +__all__ = ["EmbeddingLoss", "LabelSmoothingCrossEntropy"] class EmbeddingLoss: @@ -32,3 +34,36 @@ class EmbeddingLoss: hard_pairs = self.miner(embeddings, labels) loss = self.loss_fn(embeddings, labels, hard_pairs) return loss + + +class LabelSmoothingCrossEntropy(nn.Module): + """Label smoothing loss function.""" + + def __init__( + self, + classes: int, + smoothing: float = 0.0, + ignore_index: int = None, + dim: int = -1, + ) -> None: + super().__init__() + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + self.ignore_index = ignore_index + self.cls = classes + self.dim = dim + + def forward(self, pred: Tensor, target: Tensor) -> Tensor: + """Calculates the loss.""" + pred = pred.log_softmax(dim=self.dim) + with torch.no_grad(): + # true_dist = pred.data.clone() + true_dist = torch.zeros_like(pred) + true_dist.fill_(self.smoothing / (self.cls - 1)) + true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) + if self.ignore_index is not None: + true_dist[:, self.ignore_index] = 0 + mask = torch.nonzero(target == self.ignore_index, as_tuple=False) + if mask.dim() > 0: + true_dist.index_fill_(0, mask.squeeze(), 0.0) + return torch.mean(torch.sum(-true_dist * pred, dim=self.dim)) diff --git a/src/text_recognizer/networks/transformer/positional_encoding.py b/src/text_recognizer/networks/transformer/positional_encoding.py index a47141b..1ba5537 100644 --- a/src/text_recognizer/networks/transformer/positional_encoding.py +++ b/src/text_recognizer/networks/transformer/positional_encoding.py @@ -13,6 +13,7 @@ class PositionalEncoding(nn.Module): ) -> None: super().__init__() self.dropout = nn.Dropout(p=dropout_rate) + self.max_len = max_len pe = torch.zeros(max_len, hidden_dim) position = torch.arange(0, max_len).unsqueeze(1) diff --git a/src/text_recognizer/networks/transformer/sparse_transformer.py b/src/text_recognizer/networks/transformer/sparse_transformer.py deleted file mode 100644 index 8c391c8..0000000 --- a/src/text_recognizer/networks/transformer/sparse_transformer.py +++ /dev/null @@ -1 +0,0 @@ -"""Encoder and Decoder modules using spares activations.""" diff --git a/src/text_recognizer/networks/transformer/transformer.py b/src/text_recognizer/networks/transformer/transformer.py index 1c9c7dd..c6e943e 100644 --- a/src/text_recognizer/networks/transformer/transformer.py +++ b/src/text_recognizer/networks/transformer/transformer.py @@ -230,6 +230,7 @@ class Transformer(nn.Module): ) -> Tensor: """Forward pass through the transformer.""" if src.shape[0] != trg.shape[0]: + print(trg.shape) raise RuntimeError("The batch size of the src and trg must be the same.") if src.shape[2] != trg.shape[2]: raise RuntimeError( diff --git a/src/text_recognizer/networks/util.py b/src/text_recognizer/networks/util.py index 0d08506..b31e640 100644 --- a/src/text_recognizer/networks/util.py +++ b/src/text_recognizer/networks/util.py @@ -28,7 +28,7 @@ def sliding_window( c = images.shape[1] patches = unfold(images) patches = rearrange( - patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1] + patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1], ) return patches @@ -77,7 +77,7 @@ def configure_backbone(backbone: str, backbone_args: Dict) -> Type[nn.Module]: if "remove_layers" in backbone_args and backbone_args["remove_layers"] is not None: backbone = nn.Sequential( - *list(backbone.children())[0][: -backbone_args["remove_layers"]] + *list(backbone.children())[:][: -backbone_args["remove_layers"]] ) return backbone diff --git a/src/text_recognizer/networks/vision_transformer.py b/src/text_recognizer/networks/vision_transformer.py index 4d204d3..f227954 100644 --- a/src/text_recognizer/networks/vision_transformer.py +++ b/src/text_recognizer/networks/vision_transformer.py @@ -29,9 +29,9 @@ class VisionTransformer(nn.Module): num_heads: int, max_len: int, expansion_dim: int, - mlp_dim: int, dropout_rate: float, trg_pad_index: int, + mlp_dim: Optional[int] = None, patch_size: Tuple[int, int] = (28, 28), stride: Tuple[int, int] = (1, 14), activation: str = "gelu", @@ -46,6 +46,7 @@ class VisionTransformer(nn.Module): self.slidning_window = self._configure_sliding_window() self.character_embedding = nn.Embedding(vocab_size, hidden_dim) self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate, max_len) + self.mlp_dim = mlp_dim self.use_backbone = False if backbone is None: @@ -54,6 +55,8 @@ class VisionTransformer(nn.Module): ) else: self.backbone = configure_backbone(backbone, backbone_args) + if mlp_dim: + self.mlp = nn.Linear(mlp_dim, hidden_dim) self.use_backbone = True self.transformer = Transformer( @@ -66,13 +69,7 @@ class VisionTransformer(nn.Module): activation, ) - self.head = nn.Sequential( - nn.LayerNorm(hidden_dim), - nn.Linear(hidden_dim, mlp_dim), - nn.GELU(), - nn.Dropout(p=dropout_rate), - nn.Linear(mlp_dim, vocab_size), - ) + self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),) def _configure_sliding_window(self) -> nn.Sequential: return nn.Sequential( @@ -110,7 +107,11 @@ class VisionTransformer(nn.Module): if self.use_backbone: x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t) x = self.backbone(x) - x = rearrange(x, "(b t) h -> b t h", b=b, t=t) + if self.mlp_dim: + x = rearrange(x, "(b t) c h w -> b t (c h w)", b=b, t=t) + x = self.mlp(x) + else: + x = rearrange(x, "(b t) h -> b t h", b=b, t=t) else: x = rearrange(x, "b t c h w -> b t (c h w)", b=b, t=t) x = self.linear_projection(x) -- cgit v1.2.3-70-g09d2