blob: 8d67d0212a409f9c2f1efcf64abd5d18e0fef3f7 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
|
"""Up and down-sample with linear interpolation."""
from torch import nn, Tensor
import torch.nn.functional as F
class Upsample(nn.Module):
    """Doubles the spatial resolution of its input.

    Every element is replicated using nearest-neighbour interpolation. The
    module defines no learnable parameters.
    """

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` upsampled by a factor of 2 (nearest-neighbour).

        Args:
            x: Batched input in a layout accepted by ``F.interpolate``;
                presumably (N, C, H, W) here — TODO confirm against callers.

        Returns:
            A tensor whose spatial dimensions are twice those of ``x``.
        """
        factor = 2.0
        upsampled = F.interpolate(x, scale_factor=factor, mode="nearest")
        return upsampled
class Downsample(nn.Module):
    """Halves the spatial resolution of its input with 2x2 average pooling."""

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` downsampled by a factor of 2.

        Each non-overlapping 2x2 window is averaged (kernel size 2, stride 2).
        ``F.avg_pool2d`` expects (N, C, H, W) or (C, H, W) input. With odd
        H or W, the trailing row or column is dropped, because pooling uses
        floor rounding by default.
        """
        window = 2
        return F.avg_pool2d(x, kernel_size=window, stride=window)
|