diff options
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/image_transformer.py | 42 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/__init__.py | 6 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/positional_encoding.py | 5 |
3 files changed, 30 insertions, 23 deletions
diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py index 5a093dc..b9254c9 100644 --- a/text_recognizer/networks/image_transformer.py +++ b/text_recognizer/networks/image_transformer.py @@ -52,8 +52,10 @@ class ImageTransformer(nn.Module): # Image backbone self.backbone = backbone - self.latent_encoding = PositionalEncoding2D(hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2]) - + self.latent_encoding = PositionalEncoding2D( + hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2] + ) + # Target token embedding self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim) self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate) @@ -83,16 +85,22 @@ class ImageTransformer(nn.Module): self.head.bias.data.zero_() self.head.weight.data.uniform_(-0.1, 0.1) - nn.init.kaiming_normal_(self.latent_encoding.weight.data, a=0, mode="fan_out", nonlinearity="relu") + nn.init.kaiming_normal_( + self.latent_encoding.weight.data, a=0, mode="fan_out", nonlinearity="relu" + ) if self.latent_encoding.bias is not None: - _, fan_out = nn.init._calculate_fan_in_and_fan_out(self.latent_encoding.weight.data) + _, fan_out = nn.init._calculate_fan_in_and_fan_out( + self.latent_encoding.weight.data + ) bound = 1 / math.sqrt(fan_out) nn.init.normal_(self.latent_encoding.bias, -bound, bound) - def _configure_mapping(self, mapping: Optional[List[str]]) -> Tuple[List[str], Dict[str, int]]: + def _configure_mapping( + self, mapping: Optional[List[str]] + ) -> Tuple[List[str], Dict[str, int]]: """Configures mapping.""" if mapping is None: - mapping, inverse_mapping, _ = emnist_mapping() + mapping, inverse_mapping, _ = emnist_mapping() return mapping, inverse_mapping def encode(self, image: Tensor) -> Tensor: @@ -114,7 +122,7 @@ class ImageTransformer(nn.Module): # Add 2d encoding to the feature maps. latent = self.latent_encoding(latent) - + # Collapse features maps height and width. latent = rearrange(latent, "b c h w -> b (h w) c") return latent @@ -133,7 +141,11 @@ class ImageTransformer(nn.Module): bsz = image.shape[0] image_features = self.encode(image) - output_tokens = (torch.ones((bsz, self.max_output_length)) * self.pad_index).type_as(image).long() + output_tokens = ( + (torch.ones((bsz, self.max_output_length)) * self.pad_index) + .type_as(image) + .long() + ) output_tokens[:, 0] = self.start_index for i in range(1, self.max_output_length): trg = output_tokens[:, :i] @@ -143,17 +155,9 @@ class ImageTransformer(nn.Module): # Set all tokens after end token to be padding. for i in range(1, self.max_output_length): - indices = (output_tokens[:, i - 1] == self.end_index | (output_tokens[:, i - 1] == self.pad_index)) + indices = output_tokens[:, i - 1] == self.end_index | ( + output_tokens[:, i - 1] == self.pad_index + ) output_tokens[indices, i] = self.pad_index return output_tokens - - - - - - - - - - diff --git a/text_recognizer/networks/transformer/__init__.py b/text_recognizer/networks/transformer/__init__.py index 139cd23..652e82e 100644 --- a/text_recognizer/networks/transformer/__init__.py +++ b/text_recognizer/networks/transformer/__init__.py @@ -1,3 +1,7 @@ """Transformer modules.""" -from .positional_encoding import PositionalEncoding, PositionalEncoding2D, target_padding_mask +from .positional_encoding import ( + PositionalEncoding, + PositionalEncoding2D, + target_padding_mask, +) from .transformer import Decoder, Encoder, EncoderLayer, Transformer diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py index dbde887..5874e97 100644 --- a/text_recognizer/networks/transformer/positional_encoding.py +++ b/text_recognizer/networks/transformer/positional_encoding.py @@ -71,12 +71,11 @@ class PositionalEncoding2D(nn.Module): x += self.pe[:, : x.shape[2], : x.shape[3]] return x + def target_padding_mask(trg: Tensor, pad_index: int) -> Tensor: """Returns causal target mask.""" trg_pad_mask = (trg != pad_index)[:, None, None] trg_len = trg.shape[1] - trg_sub_mask = torch.tril( - torch.ones((trg_len, trg_len), device=trg.device) - ).bool() + trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).bool() trg_mask = trg_pad_mask & trg_sub_mask return trg_mask |