summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/conformer/depth_wise_conv.py
blob: 9465b7caf420f712127f1f54a814406a62f1272c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
"""Depthwise 1D convolution."""
from torch import nn, Tensor


class DepthwiseConv1D(nn.Module):
    """Depthwise 1D convolution whose output length matches its input length.

    The wrapped ``nn.Conv1d`` uses ``groups=in_channels``, so every input
    channel is convolved with its own filters. As a result, ``out_channels``
    must be a multiple of ``in_channels``; ``nn.Conv1d`` raises otherwise.
    ``padding="same"`` together with the default stride of 1 preserves the
    temporal length.

    Args:
        in_channels: Number of channels in the input.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Width of the convolution kernel.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int) -> None:
        super().__init__()
        # Keep the attribute name ``conv`` so parameter/state_dict keys stay stable.
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            padding="same",
            groups=in_channels,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Run the depthwise convolution.

        Args:
            x: Channels-first input in the layout ``nn.Conv1d`` accepts,
                i.e. ``(N, in_channels, L)`` or ``(in_channels, L)``.

        Returns:
            Tensor with ``out_channels`` channels and the same length ``L``.
        """
        output = self.conv(x)
        return output