From 7ae1f8f9654dcea0a9a22310ac0665a5d3202f0f Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Mon, 26 Apr 2021 22:04:47 +0200
Subject: Reformatting transformer (work in progress)

---
 text_recognizer/networks/transformer/transformer.py | 16 ++++++----------
 1 file changed, 6 insertions(+), 10 deletions(-)

(limited to 'text_recognizer/networks/transformer/transformer.py')

diff --git a/text_recognizer/networks/transformer/transformer.py b/text_recognizer/networks/transformer/transformer.py
index dd180c4..5ac2787 100644
--- a/text_recognizer/networks/transformer/transformer.py
+++ b/text_recognizer/networks/transformer/transformer.py
@@ -41,7 +41,7 @@ class _IntraLayerConnection(nn.Module):
         return self.norm(self.dropout(src) + residual)
 
 
-class _ConvolutionalLayer(nn.Module):
+class FeedForward(nn.Module):
     def __init__(
         self,
         hidden_dim: int,
@@ -82,9 +82,7 @@ class EncoderLayer(nn.Module):
     ) -> None:
         super().__init__()
         self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate)
-        self.cnn = _ConvolutionalLayer(
-            hidden_dim, expansion_dim, dropout_rate, activation
-        )
+        self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation)
         self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
         self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
@@ -99,10 +97,10 @@ class EncoderLayer(nn.Module):
 
         # Second block.
         # Apply 1D-convolution.
-        cnn_out = self.cnn(out)
+        mlp_out = self.mlp(out)
 
         # Add & norm.
-        out = self.block2(cnn_out, out)
+        out = self.block2(mlp_out, out)
 
         return out
 
@@ -148,9 +146,7 @@ class DecoderLayer(nn.Module):
         self.multihead_attention = MultiHeadAttention(
             hidden_dim, num_heads, dropout_rate
         )
-        self.cnn = _ConvolutionalLayer(
-            hidden_dim, expansion_dim, dropout_rate, activation
-        )
+        self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation)
         self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
         self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
         self.block3 = _IntraLayerConnection(dropout_rate, hidden_dim)
@@ -169,7 +165,7 @@ class DecoderLayer(nn.Module):
         out, _ = self.multihead_attention(trg, memory, memory, memory_mask)
         trg = self.block2(out, trg)
 
-        out = self.cnn(trg)
+        out = self.mlp(trg)
         out = self.block3(out, trg)
 
         return out
-- 
cgit v1.2.3-70-g09d2
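
For reference, below is a minimal sketch of what a position-wise FeedForward block with the constructor signature used in this patch (hidden_dim, expansion_dim, dropout_rate, activation) commonly looks like. The class body is not part of these hunks, so the Linear/activation/Dropout layout and the string-based activation argument are assumptions for illustration, not the repository's actual implementation.

import torch
from torch import Tensor, nn


class FeedForward(nn.Module):
    """Position-wise feed-forward sketch; only the signature comes from the patch."""

    def __init__(
        self,
        hidden_dim: int,
        expansion_dim: int,
        dropout_rate: float,
        activation: str = "relu",  # assumed to be a string flag; the repo may pass something else
    ) -> None:
        super().__init__()
        # Expand to the inner dimension, apply the non-linearity, regularize, project back.
        self.layer = nn.Sequential(
            nn.Linear(hidden_dim, expansion_dim),
            nn.GELU() if activation == "gelu" else nn.ReLU(),
            nn.Dropout(p=dropout_rate),
            nn.Linear(expansion_dim, hidden_dim),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.layer(x)


# Shape check: the block preserves (batch, sequence, hidden_dim), which is what lets
# EncoderLayer/DecoderLayer pass its output straight into the residual add & norm blocks.
mlp = FeedForward(hidden_dim=256, expansion_dim=1024, dropout_rate=0.1)
x = torch.randn(16, 100, 256)
assert mlp(x).shape == x.shape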