path: root/src/text_recognizer/networks/cnn_transformer_encoder.py
author    aktersnurra <gustaf.rydholm@gmail.com>  2020-11-08 14:54:44 +0100
committer aktersnurra <gustaf.rydholm@gmail.com>  2020-11-08 14:54:44 +0100
commit    dc28cbe2b4ed77be92ee8b2b69a20689c3bf02a4 (patch)
tree      1b5fc0d06952e13727e85c4f973a26d277068453 /src/text_recognizer/networks/cnn_transformer_encoder.py
parent    e181195a699d7fa237f256d90ab4dedffc03d405 (diff)
new updates
Diffstat (limited to 'src/text_recognizer/networks/cnn_transformer_encoder.py')
-rw-r--r--  src/text_recognizer/networks/cnn_transformer_encoder.py  73
1 file changed, 73 insertions(+), 0 deletions(-)
diff --git a/src/text_recognizer/networks/cnn_transformer_encoder.py b/src/text_recognizer/networks/cnn_transformer_encoder.py
new file mode 100644
index 0000000..93626bf
--- /dev/null
+++ b/src/text_recognizer/networks/cnn_transformer_encoder.py
@@ -0,0 +1,73 @@
+"""Network with a CNN backend and a transformer encoder head."""
+from typing import Dict, Optional
+
+from einops import rearrange
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer import PositionalEncoding
+from text_recognizer.networks.util import configure_backbone
+
+
+class CNNTransformerEncoder(nn.Module):
+ """A CNN backbone with Transformer Encoder frontend for sequence prediction."""
+
+ def __init__(
+ self,
+ backbone: str,
+ backbone_args: Dict,
+ mlp_dim: int,
+ d_model: int,
+ nhead: int = 8,
+ dropout_rate: float = 0.1,
+ activation: str = "relu",
+ num_layers: int = 6,
+ num_classes: int = 80,
+ num_channels: int = 256,
+ max_len: int = 97,
+ ) -> None:
+ super().__init__()
+ self.d_model = d_model
+ self.nhead = nhead
+ self.dropout_rate = dropout_rate
+ self.activation = activation
+ self.num_layers = num_layers
+
+ self.backbone = configure_backbone(backbone, backbone_args)
+ self.position_encoding = PositionalEncoding(d_model, dropout_rate)
+ self.encoder = self._configure_encoder()
+
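+        # A 1x1 conv maps the backbone's num_channels feature maps to max_len
+        # channels; after flattening, the channel axis becomes the sequence axis.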
+ self.conv = nn.Conv2d(num_channels, max_len, kernel_size=1)
+
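+        # Project each flattened feature map (of size h * w == mlp_dim) to d_model.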
+ self.mlp = nn.Linear(mlp_dim, d_model)
+
+ self.head = nn.Linear(d_model, num_classes)
+
+ def _configure_encoder(self) -> nn.TransformerEncoder:
+ encoder_layer = nn.TransformerEncoderLayer(
+ d_model=self.d_model,
+ nhead=self.nhead,
+ dropout=self.dropout_rate,
+ activation=self.activation,
+ )
+ norm = nn.LayerNorm(self.d_model)
+ return nn.TransformerEncoder(
+ encoder_layer=encoder_layer, num_layers=self.num_layers, norm=norm
+ )
+
+    def forward(self, x: Tensor, targets: Optional[Tensor] = None) -> Tensor:
+        """Forward pass through the network. The targets argument is currently unused."""
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
+
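+        # Extract image features; the conv output has shape (b, max_len, h, w).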
+ x = self.conv(self.backbone(x))
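+        # Flatten the spatial dims so each of the max_len channels becomes a token.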
+ x = rearrange(x, "b c h w -> b c (h w)")
+ x = self.mlp(x)
+ x = self.position_encoding(x)
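+        # nn.TransformerEncoder expects input of shape (sequence, batch, feature).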
+        x = rearrange(x, "b c h -> c b h")
+ x = self.encoder(x)
+        x = rearrange(x, "c b h -> b c h")
+ logits = self.head(x)
+
+ return logits
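
A minimal usage sketch of the class above. The backbone name and arguments are hypothetical placeholders (use whatever configure_backbone actually resolves in this repository), and the input size is illustrative; the chosen backbone must emit num_channels (256) feature maps whose flattened spatial size h * w equals mlp_dim:

    import torch
    from text_recognizer.networks.cnn_transformer_encoder import CNNTransformerEncoder

    # "ResNet" and the empty args dict are assumptions, not values from this commit.
    net = CNNTransformerEncoder(
        backbone="ResNet",
        backbone_args={},
        mlp_dim=256,  # must equal h * w of the backbone's output feature map
        d_model=128,
    )
    x = torch.randn(1, 1, 28, 952)  # dummy line image: (batch, channel, height, width)
    logits = net(x)                 # shape: (1, max_len, num_classes) == (1, 97, 80)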