Diffstat (limited to 'src/text_recognizer/networks/cnn_transformer_encoder.py')
-rw-r--r--   src/text_recognizer/networks/cnn_transformer_encoder.py   73
1 file changed, 0 insertions, 73 deletions
diff --git a/src/text_recognizer/networks/cnn_transformer_encoder.py b/src/text_recognizer/networks/cnn_transformer_encoder.py
deleted file mode 100644
index 93626bf..0000000
--- a/src/text_recognizer/networks/cnn_transformer_encoder.py
+++ /dev/null
@@ -1,73 +0,0 @@
-"""Network with a CNN backend and a transformer encoder head."""
-from typing import Dict
-
-from einops import rearrange
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.transformer import PositionalEncoding
-from text_recognizer.networks.util import configure_backbone
-
-
-class CNNTransformerEncoder(nn.Module):
-    """A CNN backbone with Transformer Encoder frontend for sequence prediction."""
-
-    def __init__(
-        self,
-        backbone: str,
-        backbone_args: Dict,
-        mlp_dim: int,
-        d_model: int,
-        nhead: int = 8,
-        dropout_rate: float = 0.1,
-        activation: str = "relu",
-        num_layers: int = 6,
-        num_classes: int = 80,
-        num_channels: int = 256,
-        max_len: int = 97,
-    ) -> None:
-        super().__init__()
-        self.d_model = d_model
-        self.nhead = nhead
-        self.dropout_rate = dropout_rate
-        self.activation = activation
-        self.num_layers = num_layers
-
-        self.backbone = configure_backbone(backbone, backbone_args)
-        self.position_encoding = PositionalEncoding(d_model, dropout_rate)
-        self.encoder = self._configure_encoder()
-
-        self.conv = nn.Conv2d(num_channels, max_len, kernel_size=1)
-
-        self.mlp = nn.Linear(mlp_dim, d_model)
-
-        self.head = nn.Linear(d_model, num_classes)
-
-    def _configure_encoder(self) -> nn.TransformerEncoder:
-        encoder_layer = nn.TransformerEncoderLayer(
-            d_model=self.d_model,
-            nhead=self.nhead,
-            dropout=self.dropout_rate,
-            activation=self.activation,
-        )
-        norm = nn.LayerNorm(self.d_model)
-        return nn.TransformerEncoder(
-            encoder_layer=encoder_layer, num_layers=self.num_layers, norm=norm
-        )
-
-    def forward(self, x: Tensor, targets: Tensor = None) -> Tensor:
-        """Forward pass through the network."""
-        if len(x.shape) < 4:
-            x = x[(None,) * (4 - len(x.shape))]
-
-        x = self.conv(self.backbone(x))
-        x = rearrange(x, "b c h w -> b c (h w)")
-        x = self.mlp(x)
-        x = self.position_encoding(x)
-        x = rearrange(x, "b c h-> c b h")
-        x = self.encoder(x)
-        x = rearrange(x, "c b h-> b c h")
-        logits = self.head(x)
-
-        return logits
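For reference, the removed forward pass maps a CNN feature map to a fixed-length token sequence: a 1x1 convolution turns the channel axis into the sequence axis (max_len tokens), the flattened spatial dimensions (h * w = mlp_dim) are projected to d_model, and the result is fed to a standard nn.TransformerEncoder before the classification head. The following is a minimal, self-contained sketch of that shape flow, not the repository's code: the toy backbone, the 28x28 input size, and all hyperparameter values are illustrative assumptions, and the PositionalEncoding step (applied after the MLP projection in the original) is omitted.

import torch
from torch import nn
from einops import rearrange

# Illustrative stand-ins; the repo builds these via configure_backbone and its own config.
backbone = nn.Conv2d(1, 256, kernel_size=3, stride=2, padding=1)  # assumed toy backbone -> (B, 256, 14, 14)
conv = nn.Conv2d(256, 97, kernel_size=1)                          # num_channels -> max_len (sequence axis)
mlp = nn.Linear(14 * 14, 128)                                     # mlp_dim = h * w -> d_model (assumed 28x28 input)
encoder_layer = nn.TransformerEncoderLayer(d_model=128, nhead=8)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=2, norm=nn.LayerNorm(128))
head = nn.Linear(128, 80)                                         # d_model -> num_classes

x = torch.randn(4, 1, 28, 28)                  # (batch, channels, height, width)
x = conv(backbone(x))                          # (4, 97, 14, 14): channels now index sequence positions
x = rearrange(x, "b c h w -> b c (h w)")       # flatten spatial dims: (4, 97, 196)
x = mlp(x)                                     # project features to d_model: (4, 97, 128)
x = rearrange(x, "b c h -> c b h")             # (seq, batch, d_model), as nn.TransformerEncoder expects by default
x = encoder(x)
x = rearrange(x, "c b h -> b c h")             # back to (batch, seq, d_model)
logits = head(x)                               # (4, 97, 80): one class distribution per sequence position
print(logits.shape)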