| author | aktersnurra <gustaf.rydholm@gmail.com> | 2021-01-07 20:10:54 +0100 |
|---|---|---|
| committer | aktersnurra <gustaf.rydholm@gmail.com> | 2021-01-07 20:10:54 +0100 |
| commit | ff9a21d333f11a42e67c1963ed67de9c0fda87c9 (patch) | |
| tree | afee959135416fe92cf6df377e84fb0a9e9714a0 | /src/text_recognizer/networks/transformer |
| parent | 25b5d6983d51e0e791b96a76beb7e49f392cd9a8 (diff) | |
Minor updates.
Diffstat (limited to 'src/text_recognizer/networks/transformer')
| Mode | Path | Lines |
|---|---|---|
| -rw-r--r-- | src/text_recognizer/networks/transformer/__init__.py | 2 |
| -rw-r--r-- | src/text_recognizer/networks/transformer/transformer.py | 26 |
2 files changed, 25 insertions, 3 deletions
```diff
diff --git a/src/text_recognizer/networks/transformer/__init__.py b/src/text_recognizer/networks/transformer/__init__.py
index 020a917..9febc88 100644
--- a/src/text_recognizer/networks/transformer/__init__.py
+++ b/src/text_recognizer/networks/transformer/__init__.py
@@ -1,3 +1,3 @@
 """Transformer modules."""
 from .positional_encoding import PositionalEncoding
-from .transformer import Decoder, Encoder, Transformer
+from .transformer import Decoder, Encoder, EncoderLayer, Transformer
diff --git a/src/text_recognizer/networks/transformer/transformer.py b/src/text_recognizer/networks/transformer/transformer.py
index c6e943e..dd180c4 100644
--- a/src/text_recognizer/networks/transformer/transformer.py
+++ b/src/text_recognizer/networks/transformer/transformer.py
@@ -6,11 +6,25 @@ import numpy as np
 import torch
 from torch import nn
 from torch import Tensor
+import torch.nn.functional as F
 
 from text_recognizer.networks.transformer.attention import MultiHeadAttention
 from text_recognizer.networks.util import activation_function
 
 
+class GEGLU(nn.Module):
+    """GLU activation for improving feedforward activations."""
+
+    def __init__(self, dim_in: int, dim_out: int) -> None:
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward propagation."""
+        x, gate = self.proj(x).chunk(2, dim=-1)
+        return x * F.gelu(gate)
+
+
 def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList:
     return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)])
 
@@ -36,9 +50,17 @@ class _ConvolutionalLayer(nn.Module):
         activation: str = "relu",
     ) -> None:
         super().__init__()
+
+        in_projection = (
+            nn.Sequential(
+                nn.Linear(hidden_dim, expansion_dim), activation_function(activation)
+            )
+            if activation != "glu"
+            else GEGLU(hidden_dim, expansion_dim)
+        )
+
         self.layer = nn.Sequential(
-            nn.Linear(in_features=hidden_dim, out_features=expansion_dim),
-            activation_function(activation),
+            in_projection,
             nn.Dropout(p=dropout_rate),
             nn.Linear(in_features=expansion_dim, out_features=hidden_dim),
         )
```
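For reference, the new in-projection can be exercised on its own. The sketch below re-creates the `GEGLU` module from the diff and wires it into the same Linear/Dropout/Linear feedforward shape used by `_ConvolutionalLayer`: the projection doubles the width, `chunk(2, dim=-1)` splits it into a value half and a gate half, and the gate is passed through GELU. The dimensions, dropout rate, standalone `feed_forward` wrapper, and dummy input are illustrative assumptions, not values taken from the repository.

```python
import torch
from torch import Tensor, nn
import torch.nn.functional as F


class GEGLU(nn.Module):
    """GELU-gated linear unit: projects to twice the target width,
    then multiplies one half by GELU of the other half."""

    def __init__(self, dim_in: int, dim_out: int) -> None:
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x: Tensor) -> Tensor:
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


# Hypothetical sizes; the real hidden/expansion widths come from the
# network configuration, not from this commit.
hidden_dim, expansion_dim, dropout_rate = 64, 256, 0.1

feed_forward = nn.Sequential(
    GEGLU(hidden_dim, expansion_dim),      # gated in-projection
    nn.Dropout(p=dropout_rate),
    nn.Linear(expansion_dim, hidden_dim),  # back to the model width
)

x = torch.randn(2, 10, hidden_dim)         # (batch, sequence, hidden_dim)
out = feed_forward(x)
print(out.shape)                           # torch.Size([2, 10, 64])
```

Note the design choice in `_ConvolutionalLayer.__init__`: the conditional keeps every activation other than `"glu"` on the original `Linear` + `activation_function` path, so only `activation="glu"` switches the feedforward over to the gated variant.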