summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/cnn_transformer.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/cnn_transformer.py')
-rw-r--r--text_recognizer/networks/cnn_transformer.py37
1 files changed, 27 insertions, 10 deletions
diff --git a/text_recognizer/networks/cnn_transformer.py b/text_recognizer/networks/cnn_transformer.py
index e23a15d..d42c29d 100644
--- a/text_recognizer/networks/cnn_transformer.py
+++ b/text_recognizer/networks/cnn_transformer.py
@@ -33,8 +33,8 @@ NUM_WORD_PIECES = 1000
class CNNTransformer(nn.Module):
def __init__(
self,
- input_shape: Sequence[int],
- output_shape: Sequence[int],
+ input_dim: Sequence[int],
+ output_dims: Sequence[int],
encoder: Union[DictConfig, Dict],
vocab_size: Optional[int] = None,
num_decoder_layers: int = 4,
@@ -43,22 +43,29 @@ class CNNTransformer(nn.Module):
expansion_dim: int = 1024,
dropout_rate: float = 0.1,
transformer_activation: str = "glu",
+ *args,
+ **kwargs,
) -> None:
+ super().__init__()
self.vocab_size = (
NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size
)
+ self.pad_index = 3 # TODO: fix me
self.hidden_dim = hidden_dim
- self.max_output_length = output_shape[0]
+ self.max_output_length = output_dims[0]
# Image backbone
self.encoder = self._configure_encoder(encoder)
+ self.encoder_proj = nn.Conv2d(256, hidden_dim, kernel_size=1)
self.feature_map_encoding = PositionalEncoding2D(
- hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2]
+ hidden_dim=hidden_dim, max_h=input_dim[1], max_w=input_dim[2]
)
# Target token embedding
self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim)
- self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+ self.trg_position_encoding = PositionalEncoding(
+ hidden_dim, dropout_rate, max_len=output_dims[0]
+ )
# Transformer decoder
self.decoder = Decoder(
@@ -86,24 +93,25 @@ class CNNTransformer(nn.Module):
self.head.weight.data.uniform_(-0.1, 0.1)
nn.init.kaiming_normal_(
- self.feature_map_encoding.weight.data,
+ self.encoder_proj.weight.data,
a=0,
mode="fan_out",
nonlinearity="relu",
)
- if self.feature_map_encoding.bias is not None:
+ if self.encoder_proj.bias is not None:
_, fan_out = nn.init._calculate_fan_in_and_fan_out(
- self.feature_map_encoding.weight.data
+ self.encoder_proj.weight.data
)
bound = 1 / math.sqrt(fan_out)
- nn.init.normal_(self.feature_map_encoding.bias, -bound, bound)
+ nn.init.normal_(self.encoder_proj.bias, -bound, bound)
@staticmethod
def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]:
encoder = OmegaConf.create(encoder)
+ args = encoder.args or {}
network_module = importlib.import_module("text_recognizer.networks")
encoder_class = getattr(network_module, encoder.type)
- return encoder_class(**encoder.args)
+ return encoder_class(**args)
def encode(self, image: Tensor) -> Tensor:
"""Extracts image features with backbone.
@@ -121,6 +129,7 @@ class CNNTransformer(nn.Module):
"""
# Extract image features.
image_features = self.encoder(image)
+ image_features = self.encoder_proj(image_features)
# Add 2d encoding to the feature maps.
image_features = self.feature_map_encoding(image_features)
@@ -133,11 +142,19 @@ class CNNTransformer(nn.Module):
"""Decodes image features with transformer decoder."""
trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index)
trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim)
+ trg = rearrange(trg, "b t d -> t b d")
trg = self.trg_position_encoding(trg)
+ trg = rearrange(trg, "t b d -> b t d")
out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None)
logits = self.head(out)
return logits
+ def forward(self, image: Tensor, trg: Tensor) -> Tensor:
+ image_features = self.encode(image)
+ output = self.decode(image_features, trg)
+ output = rearrange(output, "b t c -> b c t")
+ return output
+
def predict(self, image: Tensor) -> Tensor:
"""Transcribes text in image(s)."""
bsz = image.shape[0]