path: root/src/text_recognizer/networks/cnn_transformer.py
author     aktersnurra <gustaf.rydholm@gmail.com>  2021-01-24 22:14:17 +0100
committer  aktersnurra <gustaf.rydholm@gmail.com>  2021-01-24 22:14:17 +0100
commit     4a54d7e690897dd6e6c719fb908fd371a44c2952 (patch)
tree       04722ac94b9c3960baa5db7939d7ef01dbf535a6 /src/text_recognizer/networks/cnn_transformer.py
parent     d691b548cd0b6fc4ea184d64261f633789fee021 (diff)
Many updates, cool stuff on the way.
Diffstat (limited to 'src/text_recognizer/networks/cnn_transformer.py')
-rw-r--r--  src/text_recognizer/networks/cnn_transformer.py  15
1 file changed, 14 insertions(+), 1 deletion(-)
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py
index 43e5403..7133c26 100644
--- a/src/text_recognizer/networks/cnn_transformer.py
+++ b/src/text_recognizer/networks/cnn_transformer.py
@@ -29,14 +29,22 @@ class CNNTransformer(nn.Module):
backbone: str,
backbone_args: Optional[Dict] = None,
activation: str = "gelu",
+ pool_kernel: Optional[Tuple[int, int]] = None,
) -> None:
super().__init__()
self.trg_pad_index = trg_pad_index
self.vocab_size = vocab_size
self.backbone = configure_backbone(backbone, backbone_args)
+
+ if pool_kernel is not None:
+ self.max_pool = nn.MaxPool2d(pool_kernel, stride=2)
+ else:
+ self.max_pool = None
+
self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim)
self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
+ self.pos_dropout = nn.Dropout(p=dropout_rate)
self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
nn.init.normal_(self.character_embedding.weight, std=0.02)
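
The hunk above adds an optional max-pooling stage after the backbone and a dropout layer for the learned source position embedding. A minimal sketch of the pooling branch, outside the diff, with illustrative tensor sizes (the pool_kernel value and the backbone output shape are assumptions, not taken from this commit):

    import torch
    from torch import nn

    # Hypothetical values: pool_kernel and the backbone output shape are
    # illustrative, not taken from the commit.
    pool_kernel = (2, 2)
    max_pool = nn.MaxPool2d(pool_kernel, stride=2) if pool_kernel is not None else None

    features = torch.randn(1, 256, 8, 32)  # stands in for self.backbone(src): (b, c, h, w)
    if max_pool is not None:
        features = max_pool(features)
    print(features.shape)  # torch.Size([1, 256, 4, 16]): h and w halved by stride=2

Halving h and w quarters the flattened sequence length (h * w) that the transformer attends over, which is presumably why the pooling is made optional per backbone configuration.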
@@ -98,18 +106,23 @@ class CNNTransformer(nn.Module):
# If batch dimension is missing, it needs to be added.
if len(src.shape) < 4:
src = src[(None,) * (4 - len(src.shape))]
+
src = self.backbone(src)
+ if self.max_pool is not None:
+ src = self.max_pool(src)
+
if self.adaptive_pool is not None:
src = rearrange(src, "b c h w -> b w c h")
src = self.adaptive_pool(src)
src = src.squeeze(3)
else:
- src = rearrange(src, "b c h w -> b (w h) c")
+ src = rearrange(src, "b c h w -> b (h w) c")
b, t, _ = src.shape
src += self.src_position_embedding[:, :t]
+ src = self.pos_dropout(src)
return src
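
In the forward path above, the rearrange pattern now flattens the feature map row-major ("(h w)" instead of "(w h)"), so the resulting sequence follows reading order, and the sum of features and position embedding passes through the new pos_dropout. A small sketch with made-up sizes (max_len, hidden_dim, and the dropout rate are assumptions):

    import torch
    from einops import rearrange

    hidden_dim, max_len = 256, 100                 # assumed values
    src = torch.randn(1, hidden_dim, 4, 16)        # (b, c, h, w) features after pooling
    src = rearrange(src, "b c h w -> b (h w) c")   # row-major: top-to-bottom, left-to-right
    # The old "(w h)" pattern iterated columns first, interleaving rows of text.

    src_position_embedding = torch.randn(1, max_len, hidden_dim)
    b, t, _ = src.shape                            # t = h * w = 64 <= max_len
    src = src + src_position_embedding[:, :t]      # truncate embedding to sequence length
    src = torch.nn.functional.dropout(src, p=0.1)  # what self.pos_dropout applies in the diff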