blob: 1dbd0b879273680450b3a9505bcacc362e7382a1 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
|
"""Depthwise 1D convolution."""
from torch import nn, Tensor
class DepthwiseConv1D(nn.Module):
    """Length-preserving depthwise 1D convolution.

    Wraps a single ``nn.Conv1d`` whose ``groups`` equals ``in_channels``, so each
    input channel is convolved only with its own filters. ``nn.Conv1d`` requires
    ``out_channels`` to be a multiple of ``in_channels`` in this configuration;
    ``out_channels // in_channels`` filters are applied per input channel.
    ``padding="same"`` (with the default stride of 1) keeps the output length
    equal to the input length.

    Args:
        in_channels: Number of channels in the input.
        out_channels: Number of channels produced; must be divisible by
            ``in_channels``.
        kernel_size: Length of each 1D filter.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int) -> None:
        super().__init__()
        # One group per input channel is what makes the convolution depthwise;
        # "same" padding preserves the sequence length.
        conv_options = {"groups": in_channels, "padding": "same"}
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, **conv_options)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the depthwise convolution to ``x``.

        ``x`` is passed straight to ``nn.Conv1d``, which accepts
        ``(batch, in_channels, length)`` or unbatched ``(in_channels, length)``.
        """
        output = self.conv(x)
        return output
|