diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-11-08 14:54:44 +0100 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-11-08 14:54:44 +0100 |
commit | dc28cbe2b4ed77be92ee8b2b69a20689c3bf02a4 (patch) | |
tree | 1b5fc0d06952e13727e85c4f973a26d277068453 /src/text_recognizer/networks/cnn_transformer_encoder.py | |
parent | e181195a699d7fa237f256d90ab4dedffc03d405 (diff) |
new updates
Diffstat (limited to 'src/text_recognizer/networks/cnn_transformer_encoder.py')
-rw-r--r-- | src/text_recognizer/networks/cnn_transformer_encoder.py | 73 |
1 files changed, 73 insertions, 0 deletions
"""Network with a CNN backend and a transformer encoder head."""
from typing import Dict, Optional

from einops import rearrange
import torch
from torch import nn
from torch import Tensor

from text_recognizer.networks.transformer import PositionalEncoding
from text_recognizer.networks.util import configure_backbone


class CNNTransformerEncoder(nn.Module):
    """A CNN backbone with a Transformer Encoder frontend for sequence prediction.

    The backbone produces a (B, num_channels, H, W) feature map; a 1x1 conv
    remaps the channel axis to ``max_len`` sequence positions, the spatial axes
    are flattened into a feature vector per position, projected to ``d_model``,
    position-encoded, and passed through a standard TransformerEncoder before
    a linear classification head.
    """

    def __init__(
        self,
        backbone: str,
        backbone_args: Dict,
        mlp_dim: int,
        d_model: int,
        nhead: int = 8,
        dropout_rate: float = 0.1,
        activation: str = "relu",
        num_layers: int = 6,
        num_classes: int = 80,
        num_channels: int = 256,
        max_len: int = 97,
    ) -> None:
        """Initialize the network.

        Args:
            backbone (str): Name of the CNN backbone, resolved by
                ``configure_backbone``.
            backbone_args (Dict): Keyword arguments forwarded to the backbone
                constructor.
            mlp_dim (int): Input width of the projection layer. NOTE(review):
                this must equal H * W of the backbone's output feature map —
                a mismatch will raise a shape error in ``forward``; confirm
                against the chosen backbone/input size.
            d_model (int): Transformer embedding dimension.
            nhead (int): Number of attention heads. Defaults to 8.
            dropout_rate (float): Dropout used in both the positional encoding
                and the encoder layers. Defaults to 0.1.
            activation (str): Feed-forward activation in the encoder layers.
                Defaults to "relu".
            num_layers (int): Number of encoder layers. Defaults to 6.
            num_classes (int): Output vocabulary size. Defaults to 80.
            num_channels (int): Channel count of the backbone's feature map.
                Defaults to 256.
            max_len (int): Output sequence length produced by the 1x1 conv.
                Defaults to 97.
        """
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead
        self.dropout_rate = dropout_rate
        self.activation = activation
        self.num_layers = num_layers

        self.backbone = configure_backbone(backbone, backbone_args)
        self.position_encoding = PositionalEncoding(d_model, dropout_rate)
        self.encoder = self._configure_encoder()

        # 1x1 conv turns the channel axis into the sequence-position axis.
        self.conv = nn.Conv2d(num_channels, max_len, kernel_size=1)

        # Projects the flattened (H * W) spatial features to d_model.
        self.mlp = nn.Linear(mlp_dim, d_model)

        self.head = nn.Linear(d_model, num_classes)

    def _configure_encoder(self) -> nn.TransformerEncoder:
        """Build the TransformerEncoder stack with a final LayerNorm."""
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.d_model,
            nhead=self.nhead,
            dropout=self.dropout_rate,
            activation=self.activation,
        )
        norm = nn.LayerNorm(self.d_model)
        return nn.TransformerEncoder(
            encoder_layer=encoder_layer, num_layers=self.num_layers, norm=norm
        )

    def forward(self, x: Tensor, targets: Optional[Tensor] = None) -> Tensor:
        """Forward pass through the network.

        Args:
            x (Tensor): Input image batch. Missing leading dimensions are
                added so the backbone always receives a 4D (B, C, H, W)
                tensor.
            targets (Optional[Tensor]): Unused; accepted only so this model
                shares a call signature with decoder-based variants.
                Defaults to None.

        Returns:
            Tensor: Logits of shape (B, max_len, num_classes).
        """
        # Prepend singleton dims until the input is 4D, e.g. (H, W) -> (1, 1, H, W).
        if len(x.shape) < 4:
            x = x[(None,) * (4 - len(x.shape))]

        x = self.conv(self.backbone(x))
        # (B, max_len, H, W) -> (B, max_len, H * W); requires H * W == mlp_dim.
        x = rearrange(x, "b c h w -> b c (h w)")
        x = self.mlp(x)
        x = self.position_encoding(x)
        # nn.TransformerEncoder expects (seq, batch, feature).
        x = rearrange(x, "b c h -> c b h")
        x = self.encoder(x)
        x = rearrange(x, "c b h -> b c h")
        logits = self.head(x)

        return logits