diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-08 22:24:28 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-08 22:24:28 +0200 |
commit | d717f9e3e4dd17351f08c5822cb90d055c4513cc (patch) | |
tree | 29b10328323a5468e20f03798a09628b8edbfc5a /text_recognizer/networks | |
parent | eb5b206f7e1b08435378d2a02395307be55ee6f1 (diff) |
Working on cnn transformer module
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/cnn_tranformer.py | 123 |
1 files changed, 121 insertions, 2 deletions
diff --git a/text_recognizer/networks/cnn_tranformer.py b/text_recognizer/networks/cnn_tranformer.py index da69311..38de0ba 100644 --- a/text_recognizer/networks/cnn_tranformer.py +++ b/text_recognizer/networks/cnn_tranformer.py @@ -10,5 +10,124 @@ class CnnTransformer(nn.Module): def __attrs_pre_init__(self) -> None: super().__init__() - backbone: Type[nn.Module] = attr.ib() - head = Type[nn.Module] = attr.ib() + # Parameters, + input_dims: Tuple[int, int, int] = attr.ib() + hidden_dim: int = attr.ib() + dropout_rate: float = attr.ib() + max_output_len: int = attr.ib() + num_classes: int = attr.ib() + + # Modules. + encoder: Type[nn.Module] = attr.ib() + decoder: Type[nn.Module] = attr.ib() + embedding: nn.Embedding = attr.ib(init=False, default=None) + latent_encoder: nn.Sequential = attr.ib(init=False, default=None) + token_embedding: nn.Embedding = attr.ib(init=False, default=None) + token_pos_encoder: PositionalEncoding = attr.ib(init=False, default=None) + head: nn.Linear = attr.ib(init=False, default=None) + + def __attrs_post_init__(self) -> None: + # Latent projector for down sampling number of filters and 2d + # positional encoding. + self.latent_encoder = nn.Sequential( + nn.Conv2d( + in_channels=self.encoder.out_channels, + out_channels=self.hidden_dim, + kernel_size=1, + ), + PositionalEncoding2D( + hidden_dim=self.hidden_dim, + max_h=self.input_dims[1], + max_w=self.input_dims[2], + ), + nn.Flatten(start_dim=2), + ) + + # Token embedding. + self.token_embedding = nn.Embedding( + num_embeddings=self.num_classes, embedding_dim=self.hidden_dim + ) + + # Positional encoding for decoder tokens. + self.token_pos_encoder = PositionalEncoding( + hidden_dim=self.hidden_dim, dropout_rate=self.dropout_rate + ) + # Head + self.head = nn.Linear( + in_features=self.hidden_dim, out_features=self.num_classes + ) + + # Initalize weights for encoder. + self.init_weights() + + def init_weights(self) -> None: + """Initalize weights for decoder network and head.""" + bound = 0.1 + self.token_embedding.weight.data.uniform_(-bound, bound) + self.head.bias.data.zero_() + self.head.weight.data.uniform_(-bound, bound) + # TODO: Initalize encoder? + + def encode(self, x: Tensor) -> Tensor: + """Encodes an image into a latent feature vector. + + Args: + x (Tensor): Image tensor. + + Shape: + - x: :math: `(B, C, H, W)` + - z: :math: `(B, Sx, E)` + + where Sx is the length of the flattened feature maps projected from + the encoder. E latent dimension for each pixel in the projected + feature maps. + + Returns: + Tensor: A Latent embedding of the image. + """ + z = self.encoder(x) + z = self.latent_encoder(z) + + # Permute tensor from [B, E, Ho * Wo] to [Sx, B, E] + z = z.permute(2, 0, 1) + return z + + def decode(self, z: Tensor, trg: Tensor) -> Tensor: + """Decodes latent images embedding into word pieces. + + Args: + z (Tensor): Latent images embedding. + trg (Tensor): Word embeddings. + + Shapes: + - z: :math: `(B, Sx, E)` + - out: :math: `(B, Sy, T)` + + where Sy is the length of the output and T is the number of tokens. + + Returns: + Tensor: Sequence of word piece embeddings. + """ + pass + + def forward(self, x: Tensor, trg: Tensor) -> Tensor: + """Encodes images into word piece logtis. + + Args: + x (Tensor): Input image(s). + trg (Tensor): Target word embeddings. + + Shapes: + - x: :math: `(B, C, H, W)` + - trg: :math: `(B, Sy, T)` + + where B is the batch size, C is the number of input channels, H is + the image height and W is the image width. + """ + z = self.encode(x) + y = self.decode(z, trg) + return y + + def predict(self, x: Tensor) -> Tensor: + """Predicts text in image.""" + pass |