diff options
Diffstat (limited to 'text_recognizer/networks/transformer')
-rw-r--r-- | text_recognizer/networks/transformer/mlp.py | 46 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/norm.py | 22 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/transformer.py | 16 |
3 files changed, 74 insertions, 10 deletions
diff --git a/text_recognizer/networks/transformer/mlp.py b/text_recognizer/networks/transformer/mlp.py new file mode 100644 index 0000000..4028ab3 --- /dev/null +++ b/text_recognizer/networks/transformer/mlp.py @@ -0,0 +1,46 @@ +"""Feedforward layer in transformer. + +Stolen from lucidrains: + https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py +""" +from typing import Optional + +from torch import nn +from torch import Tensor +import torch.nn.functional as F + + +class GEGLU(nn.Module): + def __init__(self, dim_in: int, dim_out: int) -> None: + super().__init__() + self.fc = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x: Tensor) -> Tensor: + x, gate = self.fc(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + expansion_factor: int = 4, + glu: bool = True, + dropout_rate: float = 0.0, + ) -> None: + super().__init__() + inner_dim = dim * expansion_factor + dim_out = dim_out if dim_out is not None else dim + in_projection = ( + nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) + if not glu + else GEGLU(dim, inner_dim) + ) + + self.mlp = nn.Sequential( + in_projection, nn.Dropout(dropout_rate), nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x: Tensor) -> Tensor: + return self.mlp(x) diff --git a/text_recognizer/networks/transformer/norm.py b/text_recognizer/networks/transformer/norm.py new file mode 100644 index 0000000..99a5291 --- /dev/null +++ b/text_recognizer/networks/transformer/norm.py @@ -0,0 +1,22 @@ +"""Normalization layers for transfromers. + +Copied from lucidrains: + https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py + +""" +from typing import Callable, Dict + +import torch +from torch import nn +from torch import Tensor + + +class Rezero(nn.Module): + def __init__(self, fn: Callable) -> None: + super().__init__() + self.fn = fn + self.g = nn.Parameter(torch.zeros(1)) + + def forward(self, x: Tensor, **kwargs: Dict) -> Tensor: + x, *rest = self.fn(x, **kwargs) + return (x * self.g, *rest) diff --git a/text_recognizer/networks/transformer/transformer.py b/text_recognizer/networks/transformer/transformer.py index dd180c4..5ac2787 100644 --- a/text_recognizer/networks/transformer/transformer.py +++ b/text_recognizer/networks/transformer/transformer.py @@ -41,7 +41,7 @@ class _IntraLayerConnection(nn.Module): return self.norm(self.dropout(src) + residual) -class _ConvolutionalLayer(nn.Module): +class FeedForward(nn.Module): def __init__( self, hidden_dim: int, @@ -82,9 +82,7 @@ class EncoderLayer(nn.Module): ) -> None: super().__init__() self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate) - self.cnn = _ConvolutionalLayer( - hidden_dim, expansion_dim, dropout_rate, activation - ) + self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation) self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim) self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim) @@ -99,10 +97,10 @@ class EncoderLayer(nn.Module): # Second block. # Apply 1D-convolution. - cnn_out = self.cnn(out) + mlp_out = self.mlp(out) # Add & norm. - out = self.block2(cnn_out, out) + out = self.block2(mlp_out, out) return out @@ -148,9 +146,7 @@ class DecoderLayer(nn.Module): self.multihead_attention = MultiHeadAttention( hidden_dim, num_heads, dropout_rate ) - self.cnn = _ConvolutionalLayer( - hidden_dim, expansion_dim, dropout_rate, activation - ) + self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation) self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim) self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim) self.block3 = _IntraLayerConnection(dropout_rate, hidden_dim) @@ -169,7 +165,7 @@ class DecoderLayer(nn.Module): out, _ = self.multihead_attention(trg, memory, memory, memory_mask) trg = self.block2(out, trg) - out = self.cnn(trg) + out = self.mlp(trg) out = self.block3(out, trg) return out |