From 9426cc794d8c28a65bbbf5ae5466a0a343078558 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 25 Apr 2021 23:32:50 +0200 Subject: Efficient net and non working transformer model. --- text_recognizer/networks/cnn_transformer.py | 37 +++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 10 deletions(-) (limited to 'text_recognizer/networks/cnn_transformer.py') diff --git a/text_recognizer/networks/cnn_transformer.py b/text_recognizer/networks/cnn_transformer.py index e23a15d..d42c29d 100644 --- a/text_recognizer/networks/cnn_transformer.py +++ b/text_recognizer/networks/cnn_transformer.py @@ -33,8 +33,8 @@ NUM_WORD_PIECES = 1000 class CNNTransformer(nn.Module): def __init__( self, - input_shape: Sequence[int], - output_shape: Sequence[int], + input_dim: Sequence[int], + output_dims: Sequence[int], encoder: Union[DictConfig, Dict], vocab_size: Optional[int] = None, num_decoder_layers: int = 4, @@ -43,22 +43,29 @@ class CNNTransformer(nn.Module): expansion_dim: int = 1024, dropout_rate: float = 0.1, transformer_activation: str = "glu", + *args, + **kwargs, ) -> None: + super().__init__() self.vocab_size = ( NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size ) + self.pad_index = 3 # TODO: fix me self.hidden_dim = hidden_dim - self.max_output_length = output_shape[0] + self.max_output_length = output_dims[0] # Image backbone self.encoder = self._configure_encoder(encoder) + self.encoder_proj = nn.Conv2d(256, hidden_dim, kernel_size=1) self.feature_map_encoding = PositionalEncoding2D( - hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2] + hidden_dim=hidden_dim, max_h=input_dim[1], max_w=input_dim[2] ) # Target token embedding self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim) - self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate) + self.trg_position_encoding = PositionalEncoding( + hidden_dim, dropout_rate, max_len=output_dims[0] + ) # Transformer decoder self.decoder = Decoder( @@ -86,24 +93,25 @@ class CNNTransformer(nn.Module): self.head.weight.data.uniform_(-0.1, 0.1) nn.init.kaiming_normal_( - self.feature_map_encoding.weight.data, + self.encoder_proj.weight.data, a=0, mode="fan_out", nonlinearity="relu", ) - if self.feature_map_encoding.bias is not None: + if self.encoder_proj.bias is not None: _, fan_out = nn.init._calculate_fan_in_and_fan_out( - self.feature_map_encoding.weight.data + self.encoder_proj.weight.data ) bound = 1 / math.sqrt(fan_out) - nn.init.normal_(self.feature_map_encoding.bias, -bound, bound) + nn.init.normal_(self.encoder_proj.bias, -bound, bound) @staticmethod def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]: encoder = OmegaConf.create(encoder) + args = encoder.args or {} network_module = importlib.import_module("text_recognizer.networks") encoder_class = getattr(network_module, encoder.type) - return encoder_class(**encoder.args) + return encoder_class(**args) def encode(self, image: Tensor) -> Tensor: """Extracts image features with backbone. @@ -121,6 +129,7 @@ class CNNTransformer(nn.Module): """ # Extract image features. image_features = self.encoder(image) + image_features = self.encoder_proj(image_features) # Add 2d encoding to the feature maps. image_features = self.feature_map_encoding(image_features) @@ -133,11 +142,19 @@ class CNNTransformer(nn.Module): """Decodes image features with transformer decoder.""" trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index) trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim) + trg = rearrange(trg, "b t d -> t b d") trg = self.trg_position_encoding(trg) + trg = rearrange(trg, "t b d -> b t d") out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None) logits = self.head(out) return logits + def forward(self, image: Tensor, trg: Tensor) -> Tensor: + image_features = self.encode(image) + output = self.decode(image_features, trg) + output = rearrange(output, "b t c -> b c t") + return output + def predict(self, image: Tensor) -> Tensor: """Transcribes text in image(s).""" bsz = image.shape[0] -- cgit v1.2.3-70-g09d2