summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/transformer
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2021-01-07 20:10:54 +0100
committeraktersnurra <gustaf.rydholm@gmail.com>2021-01-07 20:10:54 +0100
commitff9a21d333f11a42e67c1963ed67de9c0fda87c9 (patch)
treeafee959135416fe92cf6df377e84fb0a9e9714a0 /src/text_recognizer/networks/transformer
parent25b5d6983d51e0e791b96a76beb7e49f392cd9a8 (diff)
Minor updates.
Diffstat (limited to 'src/text_recognizer/networks/transformer')
-rw-r--r--src/text_recognizer/networks/transformer/__init__.py2
-rw-r--r--src/text_recognizer/networks/transformer/transformer.py26
2 files changed, 25 insertions, 3 deletions
diff --git a/src/text_recognizer/networks/transformer/__init__.py b/src/text_recognizer/networks/transformer/__init__.py
index 020a917..9febc88 100644
--- a/src/text_recognizer/networks/transformer/__init__.py
+++ b/src/text_recognizer/networks/transformer/__init__.py
@@ -1,3 +1,3 @@
"""Transformer modules."""
from .positional_encoding import PositionalEncoding
-from .transformer import Decoder, Encoder, Transformer
+from .transformer import Decoder, Encoder, EncoderLayer, Transformer
diff --git a/src/text_recognizer/networks/transformer/transformer.py b/src/text_recognizer/networks/transformer/transformer.py
index c6e943e..dd180c4 100644
--- a/src/text_recognizer/networks/transformer/transformer.py
+++ b/src/text_recognizer/networks/transformer/transformer.py
@@ -6,11 +6,25 @@ import numpy as np
import torch
from torch import nn
from torch import Tensor
+import torch.nn.functional as F
from text_recognizer.networks.transformer.attention import MultiHeadAttention
from text_recognizer.networks.util import activation_function
+class GEGLU(nn.Module):
+ """GLU activation for improving feedforward activations."""
+
+ def __init__(self, dim_in: int, dim_out: int) -> None:
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward propagation."""
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList:
return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)])
@@ -36,9 +50,17 @@ class _ConvolutionalLayer(nn.Module):
activation: str = "relu",
) -> None:
super().__init__()
+
+ in_projection = (
+ nn.Sequential(
+ nn.Linear(hidden_dim, expansion_dim), activation_function(activation)
+ )
+ if activation != "glu"
+ else GEGLU(hidden_dim, expansion_dim)
+ )
+
self.layer = nn.Sequential(
- nn.Linear(in_features=hidden_dim, out_features=expansion_dim),
- activation_function(activation),
+ in_projection,
nn.Dropout(p=dropout_rate),
nn.Linear(in_features=expansion_dim, out_features=hidden_dim),
)