summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/cnn_transformer_encoder.py
blob: 93626bfd07502931a72c80f27accd8b88eb55065 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""Network with a CNN backend and a transformer encoder head."""
from typing import Dict

from einops import rearrange
import torch
from torch import nn
from torch import Tensor

from text_recognizer.networks.transformer import PositionalEncoding
from text_recognizer.networks.util import configure_backbone


class CNNTransformerEncoder(nn.Module):
    """Sequence classifier built from a CNN backbone and a Transformer encoder.

    The backbone output is mapped by a 1x1 convolution to ``max_len`` channels,
    each channel's spatial map is flattened and projected to ``d_model``
    features, and the resulting sequence is positionally encoded, passed through
    the Transformer encoder and classified per position by a linear head.
    """

    def __init__(
        self,
        backbone: str,
        backbone_args: Dict,
        mlp_dim: int,
        d_model: int,
        nhead: int = 8,
        dropout_rate: float = 0.1,
        activation: str = "relu",
        num_layers: int = 6,
        num_classes: int = 80,
        num_channels: int = 256,
        max_len: int = 97,
    ) -> None:
        """Create the backbone, encoder and projection layers.

        Args:
            backbone: Backbone identifier handed to ``configure_backbone``.
            backbone_args: Arguments handed to ``configure_backbone``.
            mlp_dim: Input size of the linear projection; it is applied to the
                flattened ``(h w)`` axis of the convolution output.
            d_model: Feature size used by the encoder and the projection output.
            nhead: Number of attention heads in each encoder layer.
            dropout_rate: Dropout used by the positional encoding and the
                encoder layers.
            activation: Activation name passed to the encoder layers.
            num_layers: Number of stacked encoder layers.
            num_classes: Output size of the classification head.
            num_channels: Input channels of the 1x1 convolution.
            max_len: Output channels of the 1x1 convolution, i.e. the length
                of the sequence fed to the encoder.
        """
        super().__init__()

        # Hyperparameters read later by _configure_encoder.
        self.num_layers = num_layers
        self.activation = activation
        self.dropout_rate = dropout_rate
        self.nhead = nhead
        self.d_model = d_model

        # Submodules are created in a fixed order so that parameter
        # registration (and therefore state-dict layout) stays stable.
        self.backbone = configure_backbone(backbone, backbone_args)
        self.position_encoding = PositionalEncoding(d_model, dropout_rate)
        self.encoder = self._configure_encoder()

        # 1x1 convolution: one output channel per sequence position.
        self.conv = nn.Conv2d(num_channels, max_len, kernel_size=1)

        # Projects every flattened spatial map to the model dimension.
        self.mlp = nn.Linear(mlp_dim, d_model)

        # Per-position classifier.
        self.head = nn.Linear(d_model, num_classes)

    def _configure_encoder(self) -> nn.TransformerEncoder:
        """Build the Transformer encoder stack from the stored hyperparameters."""
        layer_kwargs = {
            "d_model": self.d_model,
            "nhead": self.nhead,
            "dropout": self.dropout_rate,
            "activation": self.activation,
        }
        layer = nn.TransformerEncoderLayer(**layer_kwargs)
        # A final LayerNorm is applied after the last encoder layer.
        final_norm = nn.LayerNorm(self.d_model)
        return nn.TransformerEncoder(layer, self.num_layers, norm=final_norm)

    def forward(self, x: Tensor, targets: Tensor = None) -> Tensor:
        """Compute per-position class logits for ``x``.

        Args:
            x: Input tensor; inputs with fewer than four dimensions get
                singleton dimensions prepended until they are 4-dimensional.
            targets: Accepted for interface compatibility; not used here.

        Returns:
            Logits laid out as ``(b, c, num_classes)``, where ``c`` is the
            number of channels produced by the 1x1 convolution.
        """
        # Prepend singleton axes until the input is 4-dimensional.
        while x.dim() < 4:
            x = x.unsqueeze(0)

        feature_maps = self.conv(self.backbone(x))

        # Every channel becomes one sequence element of length h * w.
        sequence = rearrange(feature_maps, "b c h w -> b c (h w)")
        sequence = self.position_encoding(self.mlp(sequence))

        # The encoder is fed with the sequence axis first, then the output is
        # moved back to a batch-first layout for the head.
        encoded = self.encoder(rearrange(sequence, "b c h-> c b h"))
        return self.head(rearrange(encoded, "c b h-> b c h"))