diff options
Diffstat (limited to 'text_recognizer/networks/conv_transformer.py')
-rw-r--r-- | text_recognizer/networks/conv_transformer.py | 42 |
1 files changed, 19 insertions, 23 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index 09cc654..f3ba49d 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -2,7 +2,6 @@ import math from typing import Tuple -import attr from torch import nn, Tensor from text_recognizer.networks.encoders.efficientnet import EfficientNet @@ -13,32 +12,28 @@ from text_recognizer.networks.transformer.positional_encodings import ( ) -@attr.s(eq=False) class ConvTransformer(nn.Module): """Convolutional encoder and transformer decoder network.""" - def __attrs_pre_init__(self) -> None: + def __init__( + self, + input_dims: Tuple[int, int, int], + hidden_dim: int, + dropout_rate: float, + num_classes: int, + pad_index: Tensor, + encoder: EfficientNet, + decoder: Decoder, + ) -> None: super().__init__() + self.input_dims = input_dims + self.hidden_dim = hidden_dim + self.dropout_rate = dropout_rate + self.num_classes = num_classes + self.pad_index = pad_index + self.encoder = encoder + self.decoder = decoder - # Parameters and placeholders, - input_dims: Tuple[int, int, int] = attr.ib() - hidden_dim: int = attr.ib() - dropout_rate: float = attr.ib() - max_output_len: int = attr.ib() - num_classes: int = attr.ib() - pad_index: Tensor = attr.ib() - - # Modules. - encoder: EfficientNet = attr.ib() - decoder: Decoder = attr.ib() - - latent_encoder: nn.Sequential = attr.ib(init=False) - token_embedding: nn.Embedding = attr.ib(init=False) - token_pos_encoder: PositionalEncoding = attr.ib(init=False) - head: nn.Linear = attr.ib(init=False) - - def __attrs_post_init__(self) -> None: - """Post init configuration.""" # Latent projector for down sampling number of filters and 2d # positional encoding. self.latent_encoder = nn.Sequential( @@ -126,7 +121,8 @@ class ConvTransformer(nn.Module): context = self.token_embedding(context) * math.sqrt(self.hidden_dim) context = self.token_pos_encoder(context) out = self.decoder(x=context, context=z, mask=context_mask) - logits = self.head(out) + logits = self.head(out) # [B, Sy, T] + logits = logits.permute(0, 2, 1) # [B, T, Sy] return logits def forward(self, x: Tensor, context: Tensor) -> Tensor: |