blob: 769d08937f458b7138c76be2675b9b38ba2d25e5 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
|
"""Up and down-sample with linear interpolation."""
from torch import nn, Tensor
import torch.nn.functional as F
class Upsample(nn.Module):
    """Upsamples the spatial dimensions by an integer factor (default 2).

    Uses nearest-neighbour interpolation via ``F.interpolate``. That function
    treats every dimension after the first two (batch, channel) as spatial,
    so 3-D, 4-D and 5-D inputs are all accepted.

    Args:
        factor: Scale factor applied to each spatial dimension. Must be
            positive. The default of 2 reproduces the original fixed behaviour.

    Raises:
        ValueError: If ``factor`` is not positive.
    """

    def __init__(self, factor: int = 2) -> None:
        super().__init__()
        # Reject non-positive factors up front with a clear message instead
        # of failing later inside F.interpolate.
        if factor <= 0:
            raise ValueError(f"factor must be positive, got {factor}")
        self.factor = factor

    def forward(self, x: Tensor) -> Tensor:
        """Applies nearest-neighbour upsampling.

        Args:
            x: Input tensor, presumably laid out as (batch, channel, *spatial).

        Returns:
            ``x`` with every spatial size multiplied by ``factor``.
        """
        return F.interpolate(x, scale_factor=self.factor, mode="nearest")

    def extra_repr(self) -> str:
        """Includes the factor in ``print(module)`` output."""
        return f"factor={self.factor}"
class Downsample(nn.Module):
    """Downsamples the two trailing spatial dimensions by an integer factor.

    Each output element is the mean of a non-overlapping ``factor x factor``
    window (``F.avg_pool2d`` with kernel size and stride both equal to
    ``factor``). Because pooling uses ``ceil_mode=False``, trailing rows or
    columns that do not fill a whole window are dropped, so odd sizes round
    down. ``F.avg_pool2d`` accepts (C, H, W) or (N, C, H, W) inputs.

    Args:
        factor: Pooling window size and stride. Must be a positive integer.
            The default of 2 reproduces the original fixed behaviour.

    Raises:
        ValueError: If ``factor`` is less than 1.
    """

    def __init__(self, factor: int = 2) -> None:
        super().__init__()
        # avg_pool2d requires a kernel of at least 1; fail early with context.
        if factor < 1:
            raise ValueError(f"factor must be a positive integer, got {factor}")
        self.factor = factor

    def forward(self, x: Tensor) -> Tensor:
        """Applies average-pool downsampling.

        Args:
            x: Input tensor of shape (C, H, W) or (N, C, H, W).

        Returns:
            Tensor with H and W each floor-divided by ``factor``.
        """
        return F.avg_pool2d(x, kernel_size=self.factor, stride=self.factor)

    def extra_repr(self) -> str:
        """Includes the factor in ``print(module)`` output."""
        return f"factor={self.factor}"
|