Diffstat (limited to 'src/text_recognizer/networks/cnn_transformer.py')
-rw-r--r-- | src/text_recognizer/networks/cnn_transformer.py | 46
1 file changed, 16 insertions(+), 30 deletions(-)
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py
index 3da2c9f..16c7a41 100644
--- a/src/text_recognizer/networks/cnn_transformer.py
+++ b/src/text_recognizer/networks/cnn_transformer.py
@@ -1,4 +1,4 @@
-"""A DETR style transfomers but for text recognition."""
+"""A CNN-Transformer for image to text recognition."""
 from typing import Dict, Optional, Tuple
 
 from einops import rearrange
@@ -11,7 +11,7 @@ from text_recognizer.networks.util import configure_backbone
 
 
 class CNNTransformer(nn.Module):
-    """CNN+Transfomer for image to sequence prediction, sort of based on the ideas from DETR."""
+    """CNN+Transfomer for image to sequence prediction."""
 
     def __init__(
         self,
@@ -25,22 +25,14 @@ class CNNTransformer(nn.Module):
         dropout_rate: float,
         trg_pad_index: int,
         backbone: str,
-        out_channels: int,
-        max_len: int,
         backbone_args: Optional[Dict] = None,
         activation: str = "gelu",
     ) -> None:
         super().__init__()
         self.trg_pad_index = trg_pad_index
-
         self.backbone = configure_backbone(backbone, backbone_args)
         self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
-
-        # self.conv = nn.Conv2d(out_channels, max_len, kernel_size=1)
-
         self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
-        self.row_embed = nn.Parameter(torch.rand(max_len, max_len // 2))
-        self.col_embed = nn.Parameter(torch.rand(max_len, max_len // 2))
 
         self.adaptive_pool = (
             nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None
@@ -78,8 +70,12 @@ class CNNTransformer(nn.Module):
             self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
         )
 
-    def preprocess_input(self, src: Tensor) -> Tensor:
-        """Encodes src with a backbone network and a positional encoding.
+    def extract_image_features(self, src: Tensor) -> Tensor:
+        """Extracts image features with a backbone neural network.
+
+        It seem like the winning idea was to swap channels and width dimension and collapse
+        the height dimension. The transformer is learning like a baby with this implementation!!! :D
+        Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D
 
         Args:
             src (Tensor): Input tensor.
@@ -88,29 +84,19 @@ class CNNTransformer(nn.Module):
             Tensor: A input src to the transformer.
 
         """
-        # If batch dimenstion is missing, it needs to be added.
+        # If batch dimension is missing, it needs to be added.
         if len(src.shape) < 4:
             src = src[(None,) * (4 - len(src.shape))]
         src = self.backbone(src)
-        # src = self.conv(src)
+        src = rearrange(src, "b c h w -> b w c h")
         if self.adaptive_pool is not None:
            src = self.adaptive_pool(src)
-        H, W = src.shape[-2:]
-        src = rearrange(src, "b t h w -> b t (h w)")
-
-        # construct positional encodings
-        pos = torch.cat(
-            [
-                self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
-                self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
-            ],
-            dim=-1,
-        ).unsqueeze(0)
-        pos = rearrange(pos, "b h w l -> b l (h w)")
-        src = pos + 0.1 * src
+        src = src.squeeze(3)
+        src = self.position_encoding(src)
+
         return src
 
-    def preprocess_target(self, trg: Tensor) -> Tuple[Tensor, Tensor]:
+    def target_embedding(self, trg: Tensor) -> Tuple[Tensor, Tensor]:
         """Encodes target tensor with embedding and postion.
 
         Args:
@@ -126,9 +112,9 @@ class CNNTransformer(nn.Module):
 
     def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
         """Forward pass with CNN transfomer."""
-        h = self.preprocess_input(x)
+        h = self.extract_image_features(x)
         trg_mask = self._create_trg_mask(trg)
-        trg = self.preprocess_target(trg)
+        trg = self.target_embedding(trg)
         out = self.transformer(h, trg, trg_mask=trg_mask)
         logits = self.head(out)
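
The interesting part of this change is the shape bookkeeping in the new extract_image_features: the backbone's (batch, channel, height, width) feature map is rearranged so that width becomes the sequence axis and channels the feature axis, the height axis is squeezed away, and a 1D positional encoding is added, replacing the old DETR-style learned row/column embeddings. Below is a rough, standalone sketch of that flow; the batch size, image size, hidden_dim, and the nn.Sequential stand-in for the configured backbone are made up for illustration only, and the optional adaptive-pooling step is skipped.

    import torch
    from einops import rearrange
    from torch import nn

    # Hypothetical sizes, for illustration only: a batch of grayscale line images.
    batch_size, hidden_dim = 4, 256
    image = torch.randn(batch_size, 1, 28, 952)

    # Stand-in for the configured backbone: any CNN that returns a
    # (batch, hidden_dim, h, w) feature map; here pooling collapses h to 1.
    backbone = nn.Sequential(
        nn.Conv2d(1, hidden_dim, kernel_size=3, stride=2, padding=1),
        nn.AdaptiveAvgPool2d((1, 119)),
    )

    features = backbone(image)                             # (4, 256, 1, 119) = (b, c, h, w)
    features = rearrange(features, "b c h w -> b w c h")   # (4, 119, 256, 1)
    features = features.squeeze(3)                         # (4, 119, 256): width is now the sequence axis
    # The real module then applies self.position_encoding, i.e.
    # PositionalEncoding(hidden_dim, dropout_rate), to these
    # (batch, seq_len, hidden_dim) features before they go into the transformer.
    print(features.shape)                                  # torch.Size([4, 119, 256])

With width as the sequence dimension and channels as the feature dimension, each column of the feature map becomes one token for the transformer, which matches how characters are laid out left to right in a line of text; once the height axis is collapsed, the learned row/column embeddings from the DETR-style version have nothing left to encode, which is presumably why they were dropped from __init__.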