Diffstat (limited to 'text_recognizer/networks/transformer/transformer.py')
-rw-r--r--  text_recognizer/networks/transformer/transformer.py  16
1 file changed, 6 insertions(+), 10 deletions(-)
diff --git a/text_recognizer/networks/transformer/transformer.py b/text_recognizer/networks/transformer/transformer.py
index dd180c4..5ac2787 100644
--- a/text_recognizer/networks/transformer/transformer.py
+++ b/text_recognizer/networks/transformer/transformer.py
@@ -41,7 +41,7 @@ class _IntraLayerConnection(nn.Module):
         return self.norm(self.dropout(src) + residual)
 
 
-class _ConvolutionalLayer(nn.Module):
+class FeedForward(nn.Module):
     def __init__(
         self,
         hidden_dim: int,
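The hunk shows only the rename and the opening of __init__; judging by the call sites in the hunks below, the full signature is (hidden_dim, expansion_dim, dropout_rate, activation). The rename is sound: a position-wise feed-forward network is computationally identical to a pair of kernel-size-1 1D convolutions, which is presumably where the old _ConvolutionalLayer name came from. A minimal sketch of what such a block typically looks like (the actual body is outside this hunk, and the activation handling here is an assumption):

    from torch import Tensor, nn


    class FeedForward(nn.Module):
        """Position-wise feed-forward block; a sketch, the real body is not in this diff."""

        def __init__(
            self,
            hidden_dim: int,
            expansion_dim: int,
            dropout_rate: float,
            activation: str = "relu",
        ) -> None:
            super().__init__()
            # Expand to the inner dimension, apply the nonlinearity, project back.
            self.layer = nn.Sequential(
                nn.Linear(hidden_dim, expansion_dim),
                nn.GELU() if activation == "gelu" else nn.ReLU(),
                nn.Dropout(p=dropout_rate),
                nn.Linear(expansion_dim, hidden_dim),
            )

        def forward(self, x: Tensor) -> Tensor:
            return self.layer(x)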
@@ -82,9 +82,7 @@ class EncoderLayer(nn.Module):
     ) -> None:
         super().__init__()
         self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate)
-        self.cnn = _ConvolutionalLayer(
-            hidden_dim, expansion_dim, dropout_rate, activation
-        )
+        self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation)
         self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
         self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
 
@@ -99,10 +97,10 @@ class EncoderLayer(nn.Module):
 
         # Second block.
         # Apply 1D-convolution.
-        cnn_out = self.cnn(out)
+        mlp_out = self.mlp(out)
 
         # Add & norm.
-        out = self.block2(cnn_out, out)
+        out = self.block2(mlp_out, out)
 
         return out
 
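Taken together with the first hunk, the encoder layer's forward pass is two residual blocks: self-attention plus add & norm (block1), then the feed-forward plus add & norm (block2). The untouched context comment still says "Apply 1D-convolution."; with kernel size 1 that is equivalent to the per-position MLP, so the wording is stale but not wrong. A sketch of the full method as these hunks imply (the attention call signature is borrowed from the decoder hunk below, and the mask argument name is an assumption):

    from typing import Optional

    from torch import Tensor

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor:
        # First block: self-attention, then add & norm.
        out, _ = self.self_attention(src, src, src, src_mask)
        out = self.block1(out, src)

        # Second block: position-wise feed-forward, then add & norm.
        mlp_out = self.mlp(out)
        out = self.block2(mlp_out, out)

        return out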
@@ -148,9 +146,7 @@ class DecoderLayer(nn.Module):
         self.multihead_attention = MultiHeadAttention(
             hidden_dim, num_heads, dropout_rate
         )
-        self.cnn = _ConvolutionalLayer(
-            hidden_dim, expansion_dim, dropout_rate, activation
-        )
+        self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation)
         self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
         self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
         self.block3 = _IntraLayerConnection(dropout_rate, hidden_dim)
@@ -169,7 +165,7 @@ class DecoderLayer(nn.Module):
         out, _ = self.multihead_attention(trg, memory, memory, memory_mask)
         trg = self.block2(out, trg)
 
-        out = self.cnn(trg)
+        out = self.mlp(trg)
         out = self.block3(out, trg)
 
         return out
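The decoder layer adds a third residual block: masked self-attention over the target (block1, outside this hunk), cross-attention over the encoder memory (block2), and the feed-forward (block3). A sketch of the whole method under the same assumptions; the self.self_attention attribute and the mask names are not shown in the diff and are assumed:

    from typing import Optional

    from torch import Tensor

    def forward(
        self,
        trg: Tensor,
        memory: Tensor,
        trg_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
    ) -> Tensor:
        # First block (assumed, outside the hunk): masked self-attention + add & norm.
        out, _ = self.self_attention(trg, trg, trg, trg_mask)
        trg = self.block1(out, trg)

        # Second block: cross-attention over the encoder memory + add & norm.
        out, _ = self.multihead_attention(trg, memory, memory, memory_mask)
        trg = self.block2(out, trg)

        # Third block: position-wise feed-forward + add & norm.
        out = self.mlp(trg)
        out = self.block3(out, trg)

        return out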