blob: 2ef4245da339977448c3a195a1539b7a1150e8c4 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
|
"""Conformer feedforward block."""
from torch import nn, Tensor
class Feedforward(nn.Module):
    """Feedforward sub-block: expand, activate, regularize, project back.

    The stages run in this fixed order:

    1. ``nn.Linear`` from ``dim`` to ``expansion_factor * dim``
    2. ``nn.Mish`` (in place)
    3. ``nn.Dropout``
    4. ``nn.Linear`` from ``expansion_factor * dim`` back to ``dim``
    5. ``nn.Dropout``

    Args:
        dim: Size of the input and output feature dimension.
        expansion_factor: Multiplier applied to ``dim`` to obtain the width
            of the hidden projection.
        dropout: Probability passed to both ``nn.Dropout`` stages.
    """

    def __init__(
        self, dim: int, expansion_factor: int = 4, dropout: float = 0.0
    ) -> None:
        super().__init__()
        hidden_dim = expansion_factor * dim
        # The position of each stage fixes its index inside the Sequential
        # (and therefore its parameter names), so the order is kept explicit.
        stages = [
            nn.Linear(dim, hidden_dim),
            # In-place is applied to the fresh output of the projection above,
            # not to the tensor handed to ``forward``.
            nn.Mish(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through every stage and return the result.

        Args:
            x: Input tensor; its last dimension must equal ``dim`` because
                the first stage is ``nn.Linear(dim, ...)``.

        Returns:
            Tensor with the same last dimension size (``dim``) as the input.
        """
        return self.layers(x)
|