summary | refs | log | tree | commit | diff
path: root/text_recognizer/networks/image_transformer.py
diff options
context:
space:
mode:
author: Gustaf Rydholm <gustaf.rydholm@gmail.com>, 2021-04-04 23:08:46 +0200
committer: Gustaf Rydholm <gustaf.rydholm@gmail.com>, 2021-04-04 23:08:46 +0200
commit: 9e54591b7e342edc93b0bb04809a0f54045c6a15 (patch)
tree: a0f8ba9a72389e65d306c5733cbc6bbc36ea2fcf /text_recognizer/networks/image_transformer.py
parent: 2d4714fcfeb8914f240a0d36d938b434e82f191b (diff)
black reformatting
Diffstat (limited to 'text_recognizer/networks/image_transformer.py')
-rw-r--r-- text_recognizer/networks/image_transformer.py 42
1 files changed, 23 insertions, 19 deletions
diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py
index 5a093dc..b9254c9 100644
--- a/text_recognizer/networks/image_transformer.py
+++ b/text_recognizer/networks/image_transformer.py
@@ -52,8 +52,10 @@ class ImageTransformer(nn.Module):
# Image backbone
self.backbone = backbone
- self.latent_encoding = PositionalEncoding2D(hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2])
-
+ self.latent_encoding = PositionalEncoding2D(
+ hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2]
+ )
+
# Target token embedding
self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim)
self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
@@ -83,16 +85,22 @@ class ImageTransformer(nn.Module):
self.head.bias.data.zero_()
self.head.weight.data.uniform_(-0.1, 0.1)
- nn.init.kaiming_normal_(self.latent_encoding.weight.data, a=0, mode="fan_out", nonlinearity="relu")
+ nn.init.kaiming_normal_(
+ self.latent_encoding.weight.data, a=0, mode="fan_out", nonlinearity="relu"
+ )
if self.latent_encoding.bias is not None:
- _, fan_out = nn.init._calculate_fan_in_and_fan_out(self.latent_encoding.weight.data)
+ _, fan_out = nn.init._calculate_fan_in_and_fan_out(
+ self.latent_encoding.weight.data
+ )
bound = 1 / math.sqrt(fan_out)
nn.init.normal_(self.latent_encoding.bias, -bound, bound)
- def _configure_mapping(self, mapping: Optional[List[str]]) -> Tuple[List[str], Dict[str, int]]:
+ def _configure_mapping(
+ self, mapping: Optional[List[str]]
+ ) -> Tuple[List[str], Dict[str, int]]:
"""Configures mapping."""
if mapping is None:
- mapping, inverse_mapping, _ = emnist_mapping()
+ mapping, inverse_mapping, _ = emnist_mapping()
return mapping, inverse_mapping
def encode(self, image: Tensor) -> Tensor:
@@ -114,7 +122,7 @@ class ImageTransformer(nn.Module):
# Add 2d encoding to the feature maps.
latent = self.latent_encoding(latent)
-
+
# Collapse features maps height and width.
latent = rearrange(latent, "b c h w -> b (h w) c")
return latent
@@ -133,7 +141,11 @@ class ImageTransformer(nn.Module):
bsz = image.shape[0]
image_features = self.encode(image)
- output_tokens = (torch.ones((bsz, self.max_output_length)) * self.pad_index).type_as(image).long()
+ output_tokens = (
+ (torch.ones((bsz, self.max_output_length)) * self.pad_index)
+ .type_as(image)
+ .long()
+ )
output_tokens[:, 0] = self.start_index
for i in range(1, self.max_output_length):
trg = output_tokens[:, :i]
@@ -143,17 +155,9 @@ class ImageTransformer(nn.Module):
# Set all tokens after end token to be padding.
for i in range(1, self.max_output_length):
- indices = (output_tokens[:, i - 1] == self.end_index | (output_tokens[:, i - 1] == self.pad_index))
+ indices = output_tokens[:, i - 1] == self.end_index | (
+ output_tokens[:, i - 1] == self.pad_index
+ )
output_tokens[indices, i] = self.pad_index
return output_tokens
-
-
-
-
-
-
-
-
-
-