blob: 031bde991e6e66e0708719af301b57a125a54e2d (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
|
"""Conformer feedforward block."""
from torch import nn, Tensor
class MLP(nn.Module):
    """Two-layer position-wise feedforward network.

    The input's last dimension is widened from ``dim`` to ``mult * dim``,
    passed through a Mish activation and dropout, narrowed back to ``dim``,
    and finally passed through a second dropout. The last dimension of the
    output therefore matches the input's.

    Args:
        dim: Size of the last dimension of both input and output.
        mult: Width multiplier for the hidden layer (hidden size is
            ``mult * dim``).
        dropout: Probability used by both dropout layers.
    """

    def __init__(self, dim: int, mult: int = 4, dropout: float = 0.0) -> None:
        super().__init__()
        hidden_dim = mult * dim

        # The up-projection is constructed before the down-projection so that
        # parameter initialisation consumes the RNG in the same order.
        expand = nn.Linear(dim, hidden_dim)
        contract = nn.Linear(hidden_dim, dim)

        # Submodule positions (0..4) determine the state_dict keys
        # ("layers.0.*", "layers.3.*"); keep this order stable so existing
        # checkpoints continue to load.
        self.layers = nn.Sequential(
            expand,
            nn.Mish(inplace=True),  # in-place is safe: the Linear output is not reused
            nn.Dropout(dropout),
            contract,
            nn.Dropout(dropout),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply the feedforward stack to ``x`` along its last dimension."""
        out = self.layers(x)
        return out
|