blob: 5e2a7f4284e21e16659b5cd29fdc704e91c09f80 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
|
"""Simple convolutional network."""
import torch
from torch import nn, Tensor
class CNN(nn.Module):
    """Stack of strided 3x3 convolutions, each followed by a Mish activation.

    Each of the ``depth`` stages is ``Conv2d(kernel_size=3, stride=2)`` with
    PyTorch's default padding of 0, so every stage maps a spatial size ``S``
    to ``(S - 3) // 2 + 1``. The first stage reads ``in_channels`` input
    channels; every stage emits ``channels`` output channels.

    Args:
        channels: Number of output channels of every convolution.
        depth: Number of (conv, activation) stages. ``0`` yields an empty
            ``nn.Sequential``, so ``forward`` returns its input unchanged.
        in_channels: Number of channels of the input tensor. Defaults to 1,
            the value that was previously hard-coded.

    Raises:
        ValueError: If ``depth`` is negative.
    """

    def __init__(self, channels: int, depth: int, in_channels: int = 1) -> None:
        super().__init__()
        self.layers = self._build(channels, depth, in_channels)

    def _build(
        self, channels: int, depth: int, in_channels: int = 1
    ) -> nn.Sequential:
        """Return an ``nn.Sequential`` of ``depth`` (Conv2d, Mish) pairs.

        Raises:
            ValueError: If ``depth`` is negative (previously this silently
                produced an identity network).
        """
        if depth < 0:
            raise ValueError(f"depth must be non-negative, got {depth}")
        layers = []
        for i in range(depth):
            layers.append(
                nn.Conv2d(
                    # Only the first stage sees the raw input; later stages
                    # consume the previous stage's `channels` feature maps.
                    in_channels=in_channels if i == 0 else channels,
                    out_channels=channels,
                    kernel_size=3,
                    stride=2,
                )
            )
            # In-place activation overwrites the conv output to save memory;
            # that intermediate tensor is consumed only by this Mish.
            layers.append(nn.Mish(inplace=True))
        return nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Apply every (conv, activation) stage to ``x`` in order.

        ``x`` must carry ``in_channels`` channels in the dimension Conv2d
        reads (e.g. ``(N, C, H, W)``). NOTE(review): assumes H and W are large
        enough to survive ``depth`` stride-2, unpadded 3x3 stages; confirm
        that callers guarantee this.
        """
        return self.layers(x)
|