path: root/text_recognizer/networks/transformer/ff.py
Diffstat (limited to 'text_recognizer/networks/transformer/ff.py')
-rw-r--r--  text_recognizer/networks/transformer/ff.py  |  45 -
1 file changed, 0 insertions(+), 45 deletions(-)
diff --git a/text_recognizer/networks/transformer/ff.py b/text_recognizer/networks/transformer/ff.py
deleted file mode 100644
index 3ccf5b5..0000000
--- a/text_recognizer/networks/transformer/ff.py
+++ /dev/null
@@ -1,45 +0,0 @@
-"""Feedforward layer in transformer.
-
-Stolen from lucidrains:
- https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py
-"""
-from typing import Optional
-
-import torch.nn.functional as F
-from torch import Tensor, nn
-
-
-class GEGLU(nn.Module):
- def __init__(self, dim_in: int, dim_out: int) -> None:
- super().__init__()
- self.fc = nn.Linear(dim_in, dim_out * 2)
-
- def forward(self, x: Tensor) -> Tensor:
- x, gate = self.fc(x).chunk(2, dim=-1)
- return x * F.gelu(gate)
-
-
-class FeedForward(nn.Module):
- def __init__(
- self,
- dim: int,
- dim_out: Optional[int] = None,
- expansion_factor: int = 4,
- glu: bool = True,
- dropout_rate: float = 0.0,
- ) -> None:
- super().__init__()
- inner_dim = dim * expansion_factor
- dim_out = dim_out if dim_out is not None else dim
- in_projection = (
- nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
- if not glu
- else GEGLU(dim, inner_dim)
- )
-
- self.mlp = nn.Sequential(
- in_projection, nn.Dropout(dropout_rate), nn.Linear(inner_dim, dim_out)
- )
-
- def forward(self, x: Tensor) -> Tensor:
- return self.mlp(x)
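
For reference, a minimal usage sketch of the deleted module (not part of this commit): it assumes the pre-deletion tree, where text_recognizer/networks/transformer/ff.py still exists, and the dimensions are purely illustrative.

import torch

from text_recognizer.networks.transformer.ff import FeedForward

# GEGLU in-projection (dim -> expansion_factor * dim, gated), dropout,
# then a linear projection back down to dim_out (defaults to dim).
ff = FeedForward(dim=256, expansion_factor=4, glu=True, dropout_rate=0.1)

x = torch.randn(2, 100, 256)        # (batch, sequence length, model dim)
out = ff(x)
assert out.shape == (2, 100, 256)   # dim_out defaults to dim, so the shape is preserved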