Diffstat (limited to 'src/text_recognizer/networks/transformer/transformer.py')
-rw-r--r--  src/text_recognizer/networks/transformer/transformer.py  26
1 file changed, 24 insertions, 2 deletions
diff --git a/src/text_recognizer/networks/transformer/transformer.py b/src/text_recognizer/networks/transformer/transformer.py
index c6e943e..dd180c4 100644
--- a/src/text_recognizer/networks/transformer/transformer.py
+++ b/src/text_recognizer/networks/transformer/transformer.py
@@ -6,11 +6,25 @@ import numpy as np
import torch
from torch import nn
from torch import Tensor
+import torch.nn.functional as F
from text_recognizer.networks.transformer.attention import MultiHeadAttention
from text_recognizer.networks.util import activation_function
+class GEGLU(nn.Module):
+    """GEGLU activation: a GELU-gated linear unit for the feedforward block."""
+
+    def __init__(self, dim_in: int, dim_out: int) -> None:
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward propagation."""
+        x, gate = self.proj(x).chunk(2, dim=-1)
+        return x * F.gelu(gate)
+
+
def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList:
    return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)])
@@ -36,9 +50,17 @@ class _ConvolutionalLayer(nn.Module):
        activation: str = "relu",
    ) -> None:
        super().__init__()
+
+        in_projection = (
+            nn.Sequential(
+                nn.Linear(hidden_dim, expansion_dim), activation_function(activation)
+            )
+            if activation != "glu"
+            else GEGLU(hidden_dim, expansion_dim)
+        )
+
        self.layer = nn.Sequential(
-            nn.Linear(in_features=hidden_dim, out_features=expansion_dim),
-            activation_function(activation),
+            in_projection,
            nn.Dropout(p=dropout_rate),
            nn.Linear(in_features=expansion_dim, out_features=hidden_dim),
        )
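
The added GEGLU follows the gated feedforward variant from "GLU Variants Improve Transformer" (Shazeer, 2020): the input is projected to twice the target width, split into a value and a gate, and the value is scaled by GELU(gate). A minimal usage sketch, not part of the commit, assuming the package is importable as text_recognizer and using illustrative dimensions (hidden_dim=256, expansion_dim=1024):

# Sketch only: exercising GEGLU standalone; dimensions are assumptions, not repo config values.
import torch

from text_recognizer.networks.transformer.transformer import GEGLU

geglu = GEGLU(dim_in=256, dim_out=1024)   # internal projection is Linear(256, 2048)
x = torch.randn(8, 100, 256)              # (batch, seq_len, hidden_dim)
out = geglu(x)                            # split into value/gate, return value * gelu(gate)
print(out.shape)                          # torch.Size([8, 100, 1024])

In the patched _ConvolutionalLayer this replaces the plain Linear + activation input projection whenever activation == "glu"; every other activation keeps the original nn.Sequential path.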