Diffstat (limited to 'src/text_recognizer/networks/vision_transformer.py')
 src/text_recognizer/networks/vision_transformer.py | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)
diff --git a/src/text_recognizer/networks/vision_transformer.py b/src/text_recognizer/networks/vision_transformer.py
index 4d204d3..f227954 100644
--- a/src/text_recognizer/networks/vision_transformer.py
+++ b/src/text_recognizer/networks/vision_transformer.py
@@ -29,9 +29,9 @@ class VisionTransformer(nn.Module):
         num_heads: int,
         max_len: int,
         expansion_dim: int,
-        mlp_dim: int,
         dropout_rate: float,
         trg_pad_index: int,
+        mlp_dim: Optional[int] = None,
         patch_size: Tuple[int, int] = (28, 28),
         stride: Tuple[int, int] = (1, 14),
         activation: str = "gelu",
@@ -46,6 +46,7 @@ class VisionTransformer(nn.Module):
         self.slidning_window = self._configure_sliding_window()
         self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
         self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate, max_len)
+        self.mlp_dim = mlp_dim

         self.use_backbone = False
         if backbone is None:
@@ -54,6 +55,8 @@ class VisionTransformer(nn.Module):
             )
         else:
             self.backbone = configure_backbone(backbone, backbone_args)
+            if mlp_dim:
+                self.mlp = nn.Linear(mlp_dim, hidden_dim)
             self.use_backbone = True

         self.transformer = Transformer(
@@ -66,13 +69,7 @@ class VisionTransformer(nn.Module):
             activation,
         )

-        self.head = nn.Sequential(
-            nn.LayerNorm(hidden_dim),
-            nn.Linear(hidden_dim, mlp_dim),
-            nn.GELU(),
-            nn.Dropout(p=dropout_rate),
-            nn.Linear(mlp_dim, vocab_size),
-        )
+        self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)

     def _configure_sliding_window(self) -> nn.Sequential:
         return nn.Sequential(
@@ -110,7 +107,11 @@ class VisionTransformer(nn.Module):
         if self.use_backbone:
             x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
             x = self.backbone(x)
-            x = rearrange(x, "(b t) h -> b t h", b=b, t=t)
+            if self.mlp_dim:
+                x = rearrange(x, "(b t) c h w -> b t (c h w)", b=b, t=t)
+                x = self.mlp(x)
+            else:
+                x = rearrange(x, "(b t) h -> b t h", b=b, t=t)
         else:
             x = rearrange(x, "b t c h w -> b t (c h w)", b=b, t=t)
             x = self.linear_projection(x)
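
The new optional mlp_dim branch flattens the backbone's (c, h, w) feature map per patch and projects it to hidden_dim with a single nn.Linear. A minimal sketch of that path is shown below; the tensor shapes and the value of mlp_dim are illustrative assumptions, not values taken from this repository's configuration.

    # Sketch of the backbone-feature projection introduced by this change.
    # Assumed shapes: b=batch, t=patches, (c, h, w)=backbone output per patch.
    import torch
    from torch import nn
    from einops import rearrange

    b, t, c, h, w = 2, 5, 64, 2, 4        # hypothetical sizes
    hidden_dim = 128
    mlp_dim = c * h * w                   # must match the flattened feature size

    mlp = nn.Linear(mlp_dim, hidden_dim)

    features = torch.randn(b * t, c, h, w)            # backbone output for all patches
    x = rearrange(features, "(b t) c h w -> b t (c h w)", b=b, t=t)
    x = mlp(x)                                         # project to the transformer width
    print(x.shape)                                     # torch.Size([2, 5, 128])

When mlp_dim is left as None, the forward pass keeps the previous behaviour and expects the backbone to already emit a flat hidden_dim-sized vector per patch.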