summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/conformer/depth_wise_conv.py
blob: 1dbd0b879273680450b3a9505bcacc362e7382a1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
"""Depthwise 1D convolution."""
from torch import nn, Tensor


class DepthwiseConv1D(nn.Module):
    """Depthwise 1D convolution with length-preserving ("same") padding.

    Wraps a single ``nn.Conv1d`` whose ``groups`` equals ``in_channels``, so
    each input channel is filtered by its own kernels instead of being mixed
    with the other channels.

    Args:
        in_channels: Number of input channels; also used as the group count.
        out_channels: Number of output channels. ``nn.Conv1d`` requires this
            to be a multiple of ``in_channels`` (the group count).
        kernel_size: Width of each filter along the length axis.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int) -> None:
        super().__init__()
        # One group per input channel makes the convolution depthwise; "same"
        # padding keeps the output length equal to the input length (stride
        # is left at its default of 1, which "same" padding requires).
        depthwise = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            padding="same",
            groups=in_channels,
        )
        self.conv = depthwise

    def forward(self, x: Tensor) -> Tensor:
        """Run the depthwise convolution over ``x``.

        Args:
            x: Input in the layout ``nn.Conv1d`` expects (channel axis before
                the length axis) with ``in_channels`` channels.

        Returns:
            Tensor with ``out_channels`` channels and the same length as ``x``.
        """
        out = self.conv(x)
        return out