diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-10 00:32:18 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-10 00:32:18 +0200 |
commit | 3acfc51e91ba89bbdf7cef2392c274f4767e2cdf (patch) | |
tree | 5afc258f5afac667b91ecda5dc1a93687075d9ae /text_recognizer/networks/transformer/ff.py | |
parent | 14760428dd457f3749c6513ad34b822b05d6a742 (diff) |
Move mlp to ff
Diffstat (limited to 'text_recognizer/networks/transformer/ff.py')
-rw-r--r-- | text_recognizer/networks/transformer/ff.py | 46 |
1 files changed, 46 insertions, 0 deletions
"""Feedforward layer in transformer.

Stolen from lucidrains:
    https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py
"""
from typing import Optional

from torch import nn
from torch import Tensor
import torch.nn.functional as F


class GEGLU(nn.Module):
    """GELU-gated linear unit.

    Projects the input to twice the target width, then splits the result
    into a value half and a gate half; the output is the value modulated
    by ``gelu(gate)``.
    """

    def __init__(self, dim_in: int, dim_out: int) -> None:
        super().__init__()
        # One projection produces both halves (value and gate) at once.
        self.fc = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the gated projection along the last dimension."""
        projected = self.fc(x)
        value, gate = projected.chunk(2, dim=-1)
        return value * F.gelu(gate)


class FeedForward(nn.Module):
    """Position-wise feedforward block used inside a transformer layer.

    Expands the feature dimension by ``expansion_factor``, applies a GELU
    (gated via :class:`GEGLU` when ``glu`` is True) and dropout, then
    projects back down to ``dim_out`` (defaults to ``dim``).
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        expansion_factor: int = 4,
        glu: bool = True,
        dropout_rate: float = 0.0,
    ) -> None:
        super().__init__()
        inner_dim = dim * expansion_factor
        if dim_out is None:
            dim_out = dim

        # Choose the expanding projection: gated (GEGLU) or plain Linear+GELU.
        if glu:
            in_projection: nn.Module = GEGLU(dim, inner_dim)
        else:
            in_projection = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())

        self.mlp = nn.Sequential(
            in_projection,
            nn.Dropout(dropout_rate),
            nn.Linear(inner_dim, dim_out),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Run the feedforward stack on ``x``."""
        return self.mlp(x)