author	aktersnurra <gustaf.rydholm@gmail.com>	2020-11-12 23:42:03 +0100
committer	aktersnurra <gustaf.rydholm@gmail.com>	2020-11-12 23:42:03 +0100
commit	8fdb6435e15703fa5b76df19728d905650ee1aef (patch)
tree	be3bec9e5cab4ef7f9d94528d102e57ce9b16c3f /src/text_recognizer/networks/cnn_transformer_encoder.py
parent	dc28cbe2b4ed77be92ee8b2b69a20689c3bf02a4 (diff)
parent	6cb08a110620ee09fe9d8a5d008197a801d025df (diff)
Working CNN transformer.
Diffstat (limited to 'src/text_recognizer/networks/cnn_transformer_encoder.py')
-rw-r--r--	src/text_recognizer/networks/cnn_transformer_encoder.py	77
1 file changed, 0 insertions(+), 77 deletions(-)
diff --git a/src/text_recognizer/networks/cnn_transformer_encoder.py b/src/text_recognizer/networks/cnn_transformer_encoder.py
deleted file mode 100644
index 93626bf..0000000
--- a/src/text_recognizer/networks/cnn_transformer_encoder.py
+++ /dev/null
@@ -1,77 +0,0 @@
-"""Network with a CNN backend and a transformer encoder head."""
-from typing import Dict, Optional
-
-from einops import rearrange
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.transformer import PositionalEncoding
-from text_recognizer.networks.util import configure_backbone
-
-
-class CNNTransformerEncoder(nn.Module):
- """A CNN backbone with Transformer Encoder frontend for sequence prediction."""
-
-    def __init__(
-        self,
-        backbone: str,
-        backbone_args: Dict,
-        mlp_dim: int,
-        d_model: int,
-        nhead: int = 8,
-        dropout_rate: float = 0.1,
-        activation: str = "relu",
-        num_layers: int = 6,
-        num_classes: int = 80,
-        num_channels: int = 256,
-        max_len: int = 97,
-    ) -> None:
-        super().__init__()
-        self.d_model = d_model
-        self.nhead = nhead
-        self.dropout_rate = dropout_rate
-        self.activation = activation
-        self.num_layers = num_layers
-
-        self.backbone = configure_backbone(backbone, backbone_args)
-        self.position_encoding = PositionalEncoding(d_model, dropout_rate)
-        self.encoder = self._configure_encoder()
-
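-        # 1x1 conv projects the backbone's num_channels to max_len, so the channel axis becomes the output sequence axis.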
-        self.conv = nn.Conv2d(num_channels, max_len, kernel_size=1)
-
-        self.mlp = nn.Linear(mlp_dim, d_model)
-
-        self.head = nn.Linear(d_model, num_classes)
-
-    def _configure_encoder(self) -> nn.TransformerEncoder:
-        encoder_layer = nn.TransformerEncoderLayer(
-            d_model=self.d_model,
-            nhead=self.nhead,
-            dropout=self.dropout_rate,
-            activation=self.activation,
-        )
-        norm = nn.LayerNorm(self.d_model)
-        return nn.TransformerEncoder(
-            encoder_layer=encoder_layer, num_layers=self.num_layers, norm=norm
-        )
-
-    def forward(self, x: Tensor, targets: Optional[Tensor] = None) -> Tensor:
-        """Forward pass through the network."""
-        if len(x.shape) < 4:
-            x = x[(None,) * (4 - len(x.shape))]
-
-        x = self.conv(self.backbone(x))
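-        # Flatten spatial dims to (B, max_len, H * W); mlp_dim must equal H * W of the backbone output.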
- x = rearrange(x, "b c h w -> b c (h w)")
- x = self.mlp(x)
- x = self.position_encoding(x)
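-        # nn.TransformerEncoder expects sequence-first input, (S, B, E).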
- x = rearrange(x, "b c h-> c b h")
- x = self.encoder(x)
- x = rearrange(x, "c b h-> b c h")
- logits = self.head(x)
-
-        return logits
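
A minimal usage sketch of the deleted module, for reference. The backbone name and backbone_args are assumptions (configure_backbone's registry in text_recognizer.networks.util is not shown here), and the constructor arguments are only valid when mlp_dim equals H * W of the backbone's output feature map and num_channels matches its output channel count.

    import torch

    from text_recognizer.networks.cnn_transformer_encoder import CNNTransformerEncoder

    network = CNNTransformerEncoder(
        backbone="ResidualNetwork",        # assumption: a backbone name known to configure_backbone
        backbone_args={"in_channels": 1},  # assumption: backbone-specific kwargs
        mlp_dim=128,       # must equal H * W of the backbone's output feature map
        d_model=256,
        num_channels=256,  # must match the backbone's output channel count
        max_len=97,        # length of the predicted sequence
        num_classes=80,
    )

    x = torch.randn(4, 1, 28, 952)  # (B, C, H, W) batch of line images; sizes are illustrative
    logits = network(x)             # -> (B, max_len, num_classes)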