"""Network with a CNN backend and a transformer encoder head."""
from typing import Dict, Optional

from einops import rearrange
import torch
from torch import nn
from torch import Tensor

from text_recognizer.networks.transformer import PositionalEncoding
from text_recognizer.networks.util import configure_backbone


class CNNTransformerEncoder(nn.Module):
    """A CNN backbone with Transformer Encoder frontend for sequence prediction.

    A CNN backbone extracts a feature map from the input image. A 1x1
    convolution recasts the channel axis into the output sequence axis
    (``max_len``), a linear layer projects each flattened spatial vector to
    ``d_model``, and a stack of Transformer encoder layers contextualizes the
    sequence before a linear head emits per-step class logits.
    """

    def __init__(
        self,
        backbone: str,
        backbone_args: Dict,
        mlp_dim: int,
        d_model: int,
        nhead: int = 8,
        dropout_rate: float = 0.1,
        activation: str = "relu",
        num_layers: int = 6,
        num_classes: int = 80,
        num_channels: int = 256,
        max_len: int = 97,
    ) -> None:
        """Initializes the network.

        Args:
            backbone (str): Name of the CNN backbone passed to
                ``configure_backbone``.
            backbone_args (Dict): Keyword arguments for the backbone factory.
            mlp_dim (int): Input features of the projection layer; must equal
                the flattened spatial size (h * w) of the backbone feature map.
                (assumes the backbone output spatial size is fixed — TODO confirm)
            d_model (int): Transformer hidden dimension.
            nhead (int): Number of attention heads. Defaults to 8.
            dropout_rate (float): Dropout probability used both in the
                positional encoding and the encoder layers. Defaults to 0.1.
            activation (str): Activation of the encoder feedforward block.
                Defaults to "relu".
            num_layers (int): Number of Transformer encoder layers. Defaults to 6.
            num_classes (int): Number of output classes. Defaults to 80.
            num_channels (int): Channels of the backbone feature map; must match
                the backbone's output. Defaults to 256.
            max_len (int): Length of the predicted sequence. Defaults to 97.
        """
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead
        self.dropout_rate = dropout_rate
        self.activation = activation
        self.num_layers = num_layers

        self.backbone = configure_backbone(backbone, backbone_args)
        self.position_encoding = PositionalEncoding(d_model, dropout_rate)
        self.encoder = self._configure_encoder()

        # 1x1 conv turns the channel axis into the output sequence axis.
        self.conv = nn.Conv2d(num_channels, max_len, kernel_size=1)

        # Projects each flattened spatial vector (h * w) to d_model.
        self.mlp = nn.Linear(mlp_dim, d_model)

        self.head = nn.Linear(d_model, num_classes)

    def _configure_encoder(self) -> nn.TransformerEncoder:
        """Builds the Transformer encoder stack with a final LayerNorm.

        Returns:
            nn.TransformerEncoder: ``num_layers`` encoder layers followed by a
            LayerNorm over ``d_model``.
        """
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.d_model,
            nhead=self.nhead,
            dropout=self.dropout_rate,
            activation=self.activation,
        )
        norm = nn.LayerNorm(self.d_model)
        return nn.TransformerEncoder(
            encoder_layer=encoder_layer, num_layers=self.num_layers, norm=norm
        )

    def forward(self, x: Tensor, targets: Optional[Tensor] = None) -> Tensor:
        """Forward pass through the network.

        Args:
            x (Tensor): Input image; missing leading dimensions are inserted so
                the tensor is 4D (b, c, h, w) before entering the backbone.
            targets (Optional[Tensor]): Unused; kept so the signature matches
                sibling networks that consume targets. Defaults to None.

        Returns:
            Tensor: Logits of shape (b, max_len, num_classes).
        """
        # Insert leading singleton dims until the input is 4D (b, c, h, w).
        if len(x.shape) < 4:
            x = x[(None,) * (4 - len(x.shape))]

        x = self.conv(self.backbone(x))
        x = rearrange(x, "b c h w -> b c (h w)")
        x = self.mlp(x)
        x = self.position_encoding(x)
        # nn.TransformerEncoder expects (seq, batch, dim); swap back afterwards.
        x = rearrange(x, "b c h-> c b h")
        x = self.encoder(x)
        x = rearrange(x, "c b h-> b c h")
        logits = self.head(x)

        return logits