Diffstat (limited to 'src/text_recognizer/networks/cnn_transformer_encoder.py')
-rw-r--r-- | src/text_recognizer/networks/cnn_transformer_encoder.py | 73 |
1 file changed, 0 insertions, 73 deletions
diff --git a/src/text_recognizer/networks/cnn_transformer_encoder.py b/src/text_recognizer/networks/cnn_transformer_encoder.py
deleted file mode 100644
index 93626bf..0000000
--- a/src/text_recognizer/networks/cnn_transformer_encoder.py
+++ /dev/null
@@ -1,73 +0,0 @@
-"""Network with a CNN backend and a transformer encoder head."""
-from typing import Dict
-
-from einops import rearrange
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.transformer import PositionalEncoding
-from text_recognizer.networks.util import configure_backbone
-
-
-class CNNTransformerEncoder(nn.Module):
-    """A CNN backbone with Transformer Encoder frontend for sequence prediction."""
-
-    def __init__(
-        self,
-        backbone: str,
-        backbone_args: Dict,
-        mlp_dim: int,
-        d_model: int,
-        nhead: int = 8,
-        dropout_rate: float = 0.1,
-        activation: str = "relu",
-        num_layers: int = 6,
-        num_classes: int = 80,
-        num_channels: int = 256,
-        max_len: int = 97,
-    ) -> None:
-        super().__init__()
-        self.d_model = d_model
-        self.nhead = nhead
-        self.dropout_rate = dropout_rate
-        self.activation = activation
-        self.num_layers = num_layers
-
-        self.backbone = configure_backbone(backbone, backbone_args)
-        self.position_encoding = PositionalEncoding(d_model, dropout_rate)
-        self.encoder = self._configure_encoder()
-
-        self.conv = nn.Conv2d(num_channels, max_len, kernel_size=1)
-
-        self.mlp = nn.Linear(mlp_dim, d_model)
-
-        self.head = nn.Linear(d_model, num_classes)
-
-    def _configure_encoder(self) -> nn.TransformerEncoder:
-        encoder_layer = nn.TransformerEncoderLayer(
-            d_model=self.d_model,
-            nhead=self.nhead,
-            dropout=self.dropout_rate,
-            activation=self.activation,
-        )
-        norm = nn.LayerNorm(self.d_model)
-        return nn.TransformerEncoder(
-            encoder_layer=encoder_layer, num_layers=self.num_layers, norm=norm
-        )
-
-    def forward(self, x: Tensor, targets: Tensor = None) -> Tensor:
-        """Forward pass through the network."""
-        if len(x.shape) < 4:
-            x = x[(None,) * (4 - len(x.shape))]
-
-        x = self.conv(self.backbone(x))
-        x = rearrange(x, "b c h w -> b c (h w)")
-        x = self.mlp(x)
-        x = self.position_encoding(x)
-        x = rearrange(x, "b c h-> c b h")
-        x = self.encoder(x)
-        x = rearrange(x, "c b h-> b c h")
-        logits = self.head(x)
-
-        return logits
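For reference, the deleted module fed CNN features into a standard nn.TransformerEncoder by treating the backbone's channel dimension as the sequence dimension: a 1x1 convolution maps num_channels feature maps to max_len sequence positions, the flattened spatial grid (h*w, i.e. mlp_dim) is projected to d_model, and a linear head emits per-position class logits. The sketch below reproduces that shape flow in a self-contained form; the tiny stand-in backbone, the 32x32 input, and the d_model=128 / mlp_dim=256 values are illustrative assumptions (the original has no defaults for those two), and the project-internal configure_backbone and PositionalEncoding are replaced or elided as noted in the comments.

import torch
from torch import nn
from einops import rearrange

# Hypothetical stand-in for configure_backbone(...): any CNN producing
# (batch, 256, h, w) feature maps would do; here a single strided conv.
backbone = nn.Sequential(
    nn.Conv2d(1, 256, kernel_size=3, stride=2, padding=1), nn.ReLU()
)

conv = nn.Conv2d(256, 97, kernel_size=1)  # num_channels=256 -> max_len=97 positions
mlp = nn.Linear(16 * 16, 128)             # mlp_dim must equal h*w of the feature map
head = nn.Linear(128, 80)                 # d_model=128 -> num_classes=80

encoder_layer = nn.TransformerEncoderLayer(
    d_model=128, nhead=8, dropout=0.1, activation="relu"
)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=6, norm=nn.LayerNorm(128))

x = torch.randn(2, 1, 32, 32)             # (batch, channels, height, width)
x = conv(backbone(x))                     # (2, 97, 16, 16): 1x1 conv turns channels into positions
x = rearrange(x, "b c h w -> b c (h w)")  # (2, 97, 256): flatten the spatial grid
x = mlp(x)                                # (2, 97, 128): project to d_model
# (the real module applies PositionalEncoding here; elided in this sketch)
x = rearrange(x, "b c h -> c b h")        # (97, 2, 128): encoder expects sequence-first
x = encoder(x)                            # (97, 2, 128)
x = rearrange(x, "c b h -> b c h")        # back to batch-first
logits = head(x)                          # (2, 97, 80): per-position class logits

The two rearranges around the encoder exist because nn.TransformerEncoder consumes sequence-first (S, N, E) input by default; the original code's "b c h-> c b h" pattern does the same transposition.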