Diffstat (limited to 'src/text_recognizer/networks/vision_transformer.py')
 src/text_recognizer/networks/vision_transformer.py | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)
diff --git a/src/text_recognizer/networks/vision_transformer.py b/src/text_recognizer/networks/vision_transformer.py
index 4d204d3..f227954 100644
--- a/src/text_recognizer/networks/vision_transformer.py
+++ b/src/text_recognizer/networks/vision_transformer.py
@@ -29,9 +29,9 @@ class VisionTransformer(nn.Module):
         num_heads: int,
         max_len: int,
         expansion_dim: int,
-        mlp_dim: int,
         dropout_rate: float,
         trg_pad_index: int,
+        mlp_dim: Optional[int] = None,
         patch_size: Tuple[int, int] = (28, 28),
         stride: Tuple[int, int] = (1, 14),
         activation: str = "gelu",
@@ -46,6 +46,7 @@ class VisionTransformer(nn.Module):
         self.slidning_window = self._configure_sliding_window()
         self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
         self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate, max_len)
+        self.mlp_dim = mlp_dim
         self.use_backbone = False
         if backbone is None:
@@ -54,6 +55,8 @@ class VisionTransformer(nn.Module):
             )
         else:
             self.backbone = configure_backbone(backbone, backbone_args)
+            if mlp_dim:
+                self.mlp = nn.Linear(mlp_dim, hidden_dim)
             self.use_backbone = True

         self.transformer = Transformer(
@@ -66,13 +69,7 @@ class VisionTransformer(nn.Module):
             activation,
         )

-        self.head = nn.Sequential(
-            nn.LayerNorm(hidden_dim),
-            nn.Linear(hidden_dim, mlp_dim),
-            nn.GELU(),
-            nn.Dropout(p=dropout_rate),
-            nn.Linear(mlp_dim, vocab_size),
-        )
+        self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)

     def _configure_sliding_window(self) -> nn.Sequential:
         return nn.Sequential(
@@ -110,7 +107,11 @@ class VisionTransformer(nn.Module):
         if self.use_backbone:
             x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
             x = self.backbone(x)
-            x = rearrange(x, "(b t) h -> b t h", b=b, t=t)
+            if self.mlp_dim:
+                x = rearrange(x, "(b t) c h w -> b t (c h w)", b=b, t=t)
+                x = self.mlp(x)
+            else:
+                x = rearrange(x, "(b t) h -> b t h", b=b, t=t)
         else:
             x = rearrange(x, "b t c h w -> b t (c h w)", b=b, t=t)
             x = self.linear_projection(x)
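
The practical effect of the new optional mlp_dim: when a backbone is configured and mlp_dim is set, the backbone's feature map is flattened per patch and projected down to hidden_dim before entering the transformer, instead of assuming the backbone already emits a flat hidden vector. Below is a minimal standalone sketch of that path; it is not part of this commit, the conv backbone and every shape are made-up placeholders, and it assumes torch and einops are installed.

# Minimal sketch (not from this commit) of the new backbone + mlp_dim path.
# The conv backbone and all shapes below are hypothetical stand-ins.
import torch
import torch.nn as nn
from einops import rearrange

b, t, c, h, w = 2, 5, 1, 28, 28        # hypothetical batch of t image patches
hidden_dim = 256
mlp_dim = 64 * 7 * 7                   # flattened size of the stand-in backbone output

backbone = nn.Conv2d(c, 64, kernel_size=4, stride=4)  # stand-in for configure_backbone(...)
mlp = nn.Linear(mlp_dim, hidden_dim)                   # projection the commit adds

x = torch.randn(b, t, c, h, w)
x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)   # fold time into the batch dim
x = backbone(x)                                           # -> (b * t, 64, 7, 7)
x = rearrange(x, "(b t) c h w -> b t (c h w)", b=b, t=t)  # flatten the feature map per patch
x = mlp(x)                                                # -> (b, t, hidden_dim)
print(x.shape)                                            # torch.Size([2, 5, 256])

When mlp_dim is not given, the untouched "(b t) h -> b t h" branch still applies, i.e. the backbone is expected to return one flat vector per patch.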