blob: 9181323f2dc8f16f8eec1ec6e5c06ab5aaf2b8f5 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
|
"""Feedforward layer in transformer."""
from torch import Tensor, nn
class FeedForward(nn.Module):
    """Transformer feedforward block with a LayerNorm applied to its input.

    The input is layer-normalized over its last dimension. It is then
    projected from ``dim`` up to ``inner_dim``, passed through GELU and
    dropout, and projected back down to ``dim``. No residual connection is
    added here; presumably the caller adds one if needed.

    Args:
        dim: Size of the last (feature) dimension of both input and output.
        inner_dim: Width of the hidden projection.
        dropout_rate: Probability given to the ``nn.Dropout`` applied to the
            hidden activations. There is no dropout after the output
            projection. Defaults to ``0.0``.
    """

    def __init__(
        self,
        dim: int,
        inner_dim: int,
        dropout_rate: float = 0.0,
    ) -> None:
        super().__init__()
        # Keep this order. It determines the ``ff.<index>`` state_dict keys
        # (LayerNorm -> 0, first Linear -> 1, second Linear -> 4) and the
        # order in which the two Linear layers draw their random init.
        layers = [
            nn.LayerNorm(normalized_shape=dim),
            nn.Linear(in_features=dim, out_features=inner_dim),
            nn.GELU(),
            nn.Dropout(p=dropout_rate),
            nn.Linear(in_features=inner_dim, out_features=dim),
        ]
        self.ff = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through the feedforward stack.

        ``x`` must have ``dim`` as its last dimension. The result has the
        same shape as ``x``.
        """
        out = self.ff(x)
        return out
|