summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/vqvae/resize.py
blob: 8d67d0212a409f9c2f1efcf64abd5d18e0fef3f7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
"""Up- and down-sampling layers: nearest-neighbour upsampling and average-pool downsampling."""
from torch import nn, Tensor
import torch.nn.functional as F


class Upsample(nn.Module):
    """Upsamples the spatial dimensions of a tensor via ``F.interpolate``.

    The defaults (factor 2, nearest-neighbour) reproduce the original
    fixed-factor behaviour, so ``Upsample()`` is unchanged for existing callers.

    Args:
        scale_factor: Multiplier applied to every spatial dimension.
        mode: Interpolation algorithm forwarded to ``F.interpolate``
            (e.g. ``"nearest"``, ``"bilinear"``).
    """

    def __init__(self, scale_factor: float = 2.0, mode: str = "nearest") -> None:
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, x: Tensor) -> Tensor:
        """Applies upsampling.

        Args:
            x: Batched input accepted by ``F.interpolate``, e.g. ``(B, C, H, W)``.

        Returns:
            Tensor whose spatial dimensions are scaled by ``scale_factor``.
        """
        return F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode)

    def extra_repr(self) -> str:
        """Shows the configuration when the module is printed."""
        return f"scale_factor={self.scale_factor}, mode={self.mode!r}"


class Downsample(nn.Module):
    """Downsamples a 2D feature map with non-overlapping average pooling.

    The default factor of 2 reproduces the original fixed-factor behaviour,
    so ``Downsample()`` is unchanged for existing callers. Spatial sizes not
    divisible by ``factor`` are floored (trailing rows/columns are dropped),
    per ``F.avg_pool2d``'s default ``ceil_mode=False``.

    Args:
        factor: Pooling window size and stride along height and width.

    Raises:
        ValueError: If ``factor`` is smaller than 1.
    """

    def __init__(self, factor: int = 2) -> None:
        super().__init__()
        # Fail at construction time rather than deep inside the forward pass.
        if factor < 1:
            raise ValueError(f"factor must be >= 1, got {factor}")
        self.factor = factor

    def forward(self, x: Tensor) -> Tensor:
        """Applies downsampling.

        Args:
            x: Input accepted by ``F.avg_pool2d``, i.e. ``(B, C, H, W)`` or ``(C, H, W)``.

        Returns:
            Tensor with height and width divided (floored) by ``factor``.
        """
        # kernel_size == stride gives non-overlapping windows: each output
        # pixel is the mean of one factor x factor patch.
        return F.avg_pool2d(x, kernel_size=self.factor, stride=self.factor)

    def extra_repr(self) -> str:
        """Shows the configuration when the module is printed."""
        return f"factor={self.factor}"