"""Network with a CNN backend and a transformer encoder head."""
from typing import Dict
from einops import rearrange
import torch
from torch import nn
from torch import Tensor
from text_recognizer.networks.transformer import PositionalEncoding
from text_recognizer.networks.util import configure_backbone
class CNNTransformerEncoder(nn.Module):
"""A CNN backbone with Transformer Encoder frontend for sequence prediction."""
    def __init__(
        self,
        backbone: str,
        backbone_args: Dict,
        mlp_dim: int,
        d_model: int,
        nhead: int = 8,
        dropout_rate: float = 0.1,
        activation: str = "relu",
        num_layers: int = 6,
        num_classes: int = 80,
        num_channels: int = 256,
        max_len: int = 97,
    ) -> None:
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead
        self.dropout_rate = dropout_rate
        self.activation = activation
        self.num_layers = num_layers

        # CNN backbone that extracts a (B, num_channels, H, W) feature map.
        self.backbone = configure_backbone(backbone, backbone_args)
        self.position_encoding = PositionalEncoding(d_model, dropout_rate)
        self.encoder = self._configure_encoder()

        # 1x1 convolution maps the channel dimension to the output sequence
        # length; the linear layer then projects the flattened spatial
        # dimensions (H * W == mlp_dim) onto the model dimension.
        self.conv = nn.Conv2d(num_channels, max_len, kernel_size=1)
        self.mlp = nn.Linear(mlp_dim, d_model)

        # Classification head over the vocabulary.
        self.head = nn.Linear(d_model, num_classes)

    def _configure_encoder(self) -> nn.TransformerEncoder:
        """Configures the stack of Transformer encoder layers."""
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.d_model,
            nhead=self.nhead,
            dropout=self.dropout_rate,
            activation=self.activation,
        )
        norm = nn.LayerNorm(self.d_model)
        return nn.TransformerEncoder(
            encoder_layer=encoder_layer, num_layers=self.num_layers, norm=norm
        )

    def forward(self, x: Tensor, targets: Optional[Tensor] = None) -> Tensor:
        """Forward pass through the network."""
        # Add leading singleton dimensions until x is a (B, C, H, W) batch.
        if len(x.shape) < 4:
            x = x[(None,) * (4 - len(x.shape))]

        # Backbone features (B, num_channels, H, W) -> (B, max_len, H, W),
        # then flatten the spatial dimensions and project to d_model.
        x = self.conv(self.backbone(x))
        x = rearrange(x, "b c h w -> b c (h w)")
        x = self.mlp(x)
        x = self.position_encoding(x)

        # nn.TransformerEncoder expects (sequence, batch, feature).
        x = rearrange(x, "b c h -> c b h")
        x = self.encoder(x)
        x = rearrange(x, "c b h -> b c h")

        logits = self.head(x)
        return logits
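

# Minimal usage sketch (illustrative, not from this repository): the backbone
# key, its arguments, and the concrete dimensions below are assumptions. Any
# backbone registered with configure_backbone that returns a
# (B, num_channels, H, W) feature map with H * W == mlp_dim fits this pattern.
#
#     network = CNNTransformerEncoder(
#         backbone="ResidualNetwork",      # hypothetical backbone key
#         backbone_args={},                # hypothetical backbone config
#         mlp_dim=28,                      # must equal H * W of the feature map
#         d_model=128,
#     )
#     images = torch.randn(4, 1, 28, 952)  # assumed input shape
#     logits = network(images)             # -> (4, max_len, num_classes)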