summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- text_recognizer/networks/image_transformer.py | 42
-rw-r--r-- text_recognizer/networks/transformer/__init__.py | 6
-rw-r--r-- text_recognizer/networks/transformer/positional_encoding.py | 5
3 files changed, 30 insertions, 23 deletions
diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py
index 5a093dc..b9254c9 100644
--- a/text_recognizer/networks/image_transformer.py
+++ b/text_recognizer/networks/image_transformer.py
@@ -52,8 +52,10 @@ class ImageTransformer(nn.Module):
# Image backbone
self.backbone = backbone
- self.latent_encoding = PositionalEncoding2D(hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2])
-
+ self.latent_encoding = PositionalEncoding2D(
+ hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2]
+ )
+
# Target token embedding
self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim)
self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
@@ -83,16 +85,22 @@ class ImageTransformer(nn.Module):
self.head.bias.data.zero_()
self.head.weight.data.uniform_(-0.1, 0.1)
- nn.init.kaiming_normal_(self.latent_encoding.weight.data, a=0, mode="fan_out", nonlinearity="relu")
+ nn.init.kaiming_normal_(
+ self.latent_encoding.weight.data, a=0, mode="fan_out", nonlinearity="relu"
+ )
if self.latent_encoding.bias is not None:
- _, fan_out = nn.init._calculate_fan_in_and_fan_out(self.latent_encoding.weight.data)
+ _, fan_out = nn.init._calculate_fan_in_and_fan_out(
+ self.latent_encoding.weight.data
+ )
bound = 1 / math.sqrt(fan_out)
nn.init.normal_(self.latent_encoding.bias, -bound, bound)
- def _configure_mapping(self, mapping: Optional[List[str]]) -> Tuple[List[str], Dict[str, int]]:
+ def _configure_mapping(
+ self, mapping: Optional[List[str]]
+ ) -> Tuple[List[str], Dict[str, int]]:
"""Configures mapping."""
if mapping is None:
- mapping, inverse_mapping, _ = emnist_mapping()
+ mapping, inverse_mapping, _ = emnist_mapping()
return mapping, inverse_mapping
def encode(self, image: Tensor) -> Tensor:
@@ -114,7 +122,7 @@ class ImageTransformer(nn.Module):
# Add 2d encoding to the feature maps.
latent = self.latent_encoding(latent)
-
+
# Collapse features maps height and width.
latent = rearrange(latent, "b c h w -> b (h w) c")
return latent
@@ -133,7 +141,11 @@ class ImageTransformer(nn.Module):
bsz = image.shape[0]
image_features = self.encode(image)
- output_tokens = (torch.ones((bsz, self.max_output_length)) * self.pad_index).type_as(image).long()
+ output_tokens = (
+ (torch.ones((bsz, self.max_output_length)) * self.pad_index)
+ .type_as(image)
+ .long()
+ )
output_tokens[:, 0] = self.start_index
for i in range(1, self.max_output_length):
trg = output_tokens[:, :i]
@@ -143,17 +155,9 @@ class ImageTransformer(nn.Module):
# Set all tokens after end token to be padding.
for i in range(1, self.max_output_length):
- indices = (output_tokens[:, i - 1] == self.end_index | (output_tokens[:, i - 1] == self.pad_index))
+ indices = output_tokens[:, i - 1] == self.end_index | (
+ output_tokens[:, i - 1] == self.pad_index
+ )
output_tokens[indices, i] = self.pad_index
return output_tokens
-
-
-
-
-
-
-
-
-
-
diff --git a/text_recognizer/networks/transformer/__init__.py b/text_recognizer/networks/transformer/__init__.py
index 139cd23..652e82e 100644
--- a/text_recognizer/networks/transformer/__init__.py
+++ b/text_recognizer/networks/transformer/__init__.py
@@ -1,3 +1,7 @@
"""Transformer modules."""
-from .positional_encoding import PositionalEncoding, PositionalEncoding2D, target_padding_mask
+from .positional_encoding import (
+ PositionalEncoding,
+ PositionalEncoding2D,
+ target_padding_mask,
+)
from .transformer import Decoder, Encoder, EncoderLayer, Transformer
diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py
index dbde887..5874e97 100644
--- a/text_recognizer/networks/transformer/positional_encoding.py
+++ b/text_recognizer/networks/transformer/positional_encoding.py
@@ -71,12 +71,11 @@ class PositionalEncoding2D(nn.Module):
x += self.pe[:, : x.shape[2], : x.shape[3]]
return x
+
def target_padding_mask(trg: Tensor, pad_index: int) -> Tensor:
"""Returns causal target mask."""
trg_pad_mask = (trg != pad_index)[:, None, None]
trg_len = trg.shape[1]
- trg_sub_mask = torch.tril(
- torch.ones((trg_len, trg_len), device=trg.device)
- ).bool()
+ trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).bool()
trg_mask = trg_pad_mask & trg_sub_mask
return trg_mask