blob: 2ef4245da339977448c3a195a1539b7a1150e8c4 (
plain)
| 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
 | """Conformer feedforward block."""
from torch import nn, Tensor
class Feedforward(nn.Module):
    """Feedforward block: widen the features, activate, then narrow again.

    ``self.layers`` holds, in order: a linear map from ``dim`` to
    ``expansion_factor * dim``, a Mish activation, dropout, a linear map
    back to ``dim``, and a final dropout.

    Args:
        dim: Size of the input and output feature dimension.
        expansion_factor: Multiplier applied to ``dim`` to obtain the width
            of the hidden layer.
        dropout: Probability handed to both ``nn.Dropout`` layers.
    """

    def __init__(
        self,
        dim: int,
        expansion_factor: int = 4,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        hidden_dim = expansion_factor * dim

        # The two linear layers are created in expand-then-project order so
        # that parameter initialisation draws from the RNG in that order.
        expand = nn.Linear(dim, hidden_dim)
        # In-place: the activation overwrites the output of ``expand``.
        activation = nn.Mish(inplace=True)
        hidden_dropout = nn.Dropout(dropout)
        project = nn.Linear(hidden_dim, dim)
        output_dropout = nn.Dropout(dropout)

        self.layers = nn.Sequential(
            expand,
            activation,
            hidden_dropout,
            project,
            output_dropout,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Pass ``x`` through the layer stack and return the result.

        Args:
            x: Input tensor; its last dimension must equal ``dim`` for the
                first linear layer to accept it.

        Returns:
            The output of ``self.layers`` applied to ``x``.
        """
        out: Tensor = self.layers(x)
        return out
 |