author    Gustaf Rydholm <gustaf.rydholm@gmail.com>  2022-06-10 00:32:18 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com>  2022-06-10 00:32:18 +0200
commit    3acfc51e91ba89bbdf7cef2392c274f4767e2cdf
tree      5afc258f5afac667b91ecda5dc1a93687075d9ae
parent    14760428dd457f3749c6513ad34b822b05d6a742
Move mlp to ff
Diffstat (limited to 'text_recognizer/networks/transformer/mlp.py')
-rw-r--r-- text_recognizer/networks/transformer/mlp.py | 46
1 file changed, 0 insertions(+), 46 deletions(-)
diff --git a/text_recognizer/networks/transformer/mlp.py b/text_recognizer/networks/transformer/mlp.py
deleted file mode 100644
index 4028ab3..0000000
--- a/text_recognizer/networks/transformer/mlp.py
+++ /dev/null
@@ -1,46 +0,0 @@
-"""Feedforward layer in transformer.
-
-Stolen from lucidrains:
- https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py
-"""
-from typing import Optional
-
-from torch import nn
-from torch import Tensor
-import torch.nn.functional as F
-
-
-class GEGLU(nn.Module):
- def __init__(self, dim_in: int, dim_out: int) -> None:
- super().__init__()
- self.fc = nn.Linear(dim_in, dim_out * 2)
-
- def forward(self, x: Tensor) -> Tensor:
- x, gate = self.fc(x).chunk(2, dim=-1)
- return x * F.gelu(gate)
-
-
-class FeedForward(nn.Module):
- def __init__(
- self,
- dim: int,
- dim_out: Optional[int] = None,
- expansion_factor: int = 4,
- glu: bool = True,
- dropout_rate: float = 0.0,
- ) -> None:
- super().__init__()
- inner_dim = dim * expansion_factor
- dim_out = dim_out if dim_out is not None else dim
- in_projection = (
- nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
- if not glu
- else GEGLU(dim, inner_dim)
- )
-
- self.mlp = nn.Sequential(
- in_projection, nn.Dropout(dropout_rate), nn.Linear(inner_dim, dim_out)
- )
-
- def forward(self, x: Tensor) -> Tensor:
- return self.mlp(x)
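
For context on the removed code: GEGLU is the GELU-gated linear unit from Shazeer's "GLU Variants Improve Transformer" (arXiv:2002.05202). A single nn.Linear projects to twice the output width, and the result is split into a value half and a gate half, combined as value * GELU(gate). A standalone sketch of just the gating step (tensor sizes are illustrative, not taken from the repo):

    import torch
    import torch.nn.functional as F
    from torch import nn

    fc = nn.Linear(64, 128 * 2)           # one projection yields both halves
    x = torch.randn(8, 64)
    value, gate = fc(x).chunk(2, dim=-1)  # split into two (8, 128) tensors
    out = value * F.gelu(gate)            # GEGLU: value gated by GELU(gate)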
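
And a minimal usage sketch of the deleted FeedForward block itself (the dimensions are hypothetical, and the import path is the pre-move one this commit removes; after the move it presumably resolves under the new ff module, per the commit message):

    import torch
    from text_recognizer.networks.transformer.mlp import FeedForward  # pre-move path

    ff = FeedForward(dim=512, expansion_factor=4, glu=True, dropout_rate=0.1)
    x = torch.randn(1, 32, 512)  # (batch, sequence length, dim)
    out = ff(x)                  # shape preserved: (1, 32, 512)

With glu=True the inner projection is GEGLU(512, 2048), i.e. a Linear(512, 4096) split into value and gate; with glu=False it falls back to Linear(512, 2048) followed by a plain GELU.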