author | aktersnurra <gustaf.rydholm@gmail.com> | 2021-01-24 22:14:17 +0100
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2021-01-24 22:14:17 +0100
commit | 4a54d7e690897dd6e6c719fb908fd371a44c2952 (patch)
tree | 04722ac94b9c3960baa5db7939d7ef01dbf535a6 /src/text_recognizer/networks/cnn_transformer.py
parent | d691b548cd0b6fc4ea184d64261f633789fee021 (diff)
Many updates, cool stuff on the way.
Diffstat (limited to 'src/text_recognizer/networks/cnn_transformer.py')
-rw-r--r-- | src/text_recognizer/networks/cnn_transformer.py | 15
1 file changed, 14 insertions(+), 1 deletion(-)
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py
index 43e5403..7133c26 100644
--- a/src/text_recognizer/networks/cnn_transformer.py
+++ b/src/text_recognizer/networks/cnn_transformer.py
@@ -29,14 +29,22 @@ class CNNTransformer(nn.Module):
         backbone: str,
         backbone_args: Optional[Dict] = None,
         activation: str = "gelu",
+        pool_kernel: Optional[Tuple[int, int]] = None,
     ) -> None:
         super().__init__()
         self.trg_pad_index = trg_pad_index
         self.vocab_size = vocab_size
         self.backbone = configure_backbone(backbone, backbone_args)
+
+        if pool_kernel is not None:
+            self.max_pool = nn.MaxPool2d(pool_kernel, stride=2)
+        else:
+            self.max_pool = None
+
         self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim)
         self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
+        self.pos_dropout = nn.Dropout(p=dropout_rate)
         self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
 
         nn.init.normal_(self.character_embedding.weight, std=0.02)
@@ -98,18 +106,23 @@ class CNNTransformer(nn.Module):
         # If batch dimension is missing, it needs to be added.
         if len(src.shape) < 4:
             src = src[(None,) * (4 - len(src.shape))]
+
         src = self.backbone(src)
+        if self.max_pool is not None:
+            src = self.max_pool(src)
+
         if self.adaptive_pool is not None:
             src = rearrange(src, "b c h w -> b w c h")
             src = self.adaptive_pool(src)
             src = src.squeeze(3)
         else:
-            src = rearrange(src, "b c h w -> b (w h) c")
+            src = rearrange(src, "b c h w -> b (h w) c")
 
         b, t, _ = src.shape
         src += self.src_position_embedding[:, :t]
+        src = self.pos_dropout(src)
 
         return src
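For context, here is a minimal standalone sketch of what the changed source path does to a feature map. All shapes and values (batch, hidden_dim, h, w, max_len, dropout_rate, pool_kernel) are made up for illustration and do not come from the repository's configuration; the dummy tensor stands in for the backbone output. The sketch walks the three changes in the commit: the optional nn.MaxPool2d stage, the flattening order switched from (w h) to (h w), and the new dropout applied after the learned source position embedding.

import torch
from torch import nn
from einops import rearrange

# Illustrative values; the real ones come from the network config.
batch, hidden_dim, h, w = 2, 256, 8, 32
max_len, dropout_rate = 512, 0.2
pool_kernel = (2, 2)  # the new optional constructor argument

# Stand-in for the backbone's output feature map.
src = torch.randn(batch, hidden_dim, h, w)

# New in this commit: optional max pooling of the backbone features.
max_pool = nn.MaxPool2d(pool_kernel, stride=2)
src = max_pool(src)  # (2, 256, 8, 32) -> (2, 256, 4, 16)

# Changed flattening order: (h w) walks the feature map row by row,
# where the old (w h) walked it column by column.
src = rearrange(src, "b c h w -> b (h w) c")  # (2, 64, 256)

# Learned source position embedding, now followed by the new dropout.
src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
pos_dropout = nn.Dropout(p=dropout_rate)

b, t, _ = src.shape
src = src + src_position_embedding[:, :t]
src = pos_dropout(src)
print(src.shape)  # torch.Size([2, 64, 256])

Applying dropout after adding position information mirrors the standard Transformer recipe of regularizing the position-augmented embeddings before they enter the encoder.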