blob: 9465b7caf420f712127f1f54a814406a62f1272c (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
|
"""Depthwise 1D convolution."""
from torch import nn, Tensor
class DepthwiseConv1D(nn.Module):
    """Depthwise 1D convolution that preserves sequence length.

    Wraps ``nn.Conv1d`` with ``groups=in_channels``: each input channel is
    convolved only with its own filters, so channels are not mixed.
    ``padding="same"`` (with the default stride of 1) keeps the output
    length equal to the input length.

    Input follows ``nn.Conv1d``: ``(batch, in_channels, length)`` or
    unbatched ``(in_channels, length)``.

    Args:
        in_channels: Number of input channels; also used as ``groups``.
        out_channels: Number of output channels. Because ``groups`` equals
            ``in_channels``, PyTorch requires this to be a multiple of
            ``in_channels`` (the per-channel depth multiplier).
        kernel_size: Length of each convolution kernel.
        dilation: Spacing between kernel taps. Defaults to 1, the
            ``nn.Conv1d`` default.
        bias: Whether the convolution adds a learnable bias. Defaults to
            True, the ``nn.Conv1d`` default.

    Raises:
        ValueError: From ``nn.Conv1d`` when ``out_channels`` is not divisible
            by ``in_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        *,
        dilation: int = 1,
        bias: bool = True,
    ) -> None:
        super().__init__()
        # The attribute name ``conv`` is part of the state_dict key layout
        # ("conv.weight", "conv.bias"). Keep it stable for checkpoints.
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            groups=in_channels,  # one filter group per input channel
            padding="same",
            dilation=dilation,
            bias=bias,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply the depthwise convolution; the length dimension is preserved."""
        return self.conv(x)
|