1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
|
"""Conformer convolutional block."""
from einops import rearrange
from einops.layers.torch import Rearrange
from torch import nn, Tensor
from text_recognizer.networks.conformer.depth_wise_conv import DepthwiseConv1D
from text_recognizer.networks.conformer.glu import GLU
class ConformerConv(nn.Module):
def __init__(
self,
dim: int,
expansion_factor: int = 2,
kernel_size: int = 31,
dropout: int = 0.0,
) -> None:
super().__init__()
inner_dim = expansion_factor * dim
self.layers = nn.Sequential(
nn.LayerNorm(dim),
Rearrange("b n c -> b c n"),
nn.Conv1D(dim, 2 * inner_dim, 1),
GLU(dim=1),
DepthwiseConv1D(inner_dim, inner_dim, kernel_size),
nn.BatchNorm1d(inner_dim),
nn.Mish(inplace=True),
nn.Conv1D(inner_dim, dim, 1),
Rearrange("b c n -> b n c"),
nn.Dropout(dropout),
)
def forward(self, x: Tensor) -> Tensor:
return self.layers(x)
|