summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/vqvae/resize.py
blob: 769d08937f458b7138c76be2675b9b38ba2d25e5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
"""Up- and down-sampling with nearest-neighbour interpolation and average pooling."""
from torch import nn, Tensor
import torch.nn.functional as F


class Upsample(nn.Module):
    """Upsamples the spatial dimensions of a tensor by an integer factor.

    Args:
        factor: Multiplier applied to every spatial dimension. Defaults to 2,
            which matches the original fixed behaviour.
        mode: Interpolation algorithm forwarded to ``F.interpolate``. Defaults
            to ``"nearest"``, which replicates each input element.
    """

    def __init__(self, factor: int = 2, mode: str = "nearest") -> None:
        super().__init__()
        self.factor = factor
        self.mode = mode

    def extra_repr(self) -> str:
        """Shows the configuration when the module is printed."""
        return f"factor={self.factor}, mode={self.mode!r}"

    def forward(self, x: Tensor) -> Tensor:
        """Applies upsampling.

        Args:
            x: Input batch. Presumably (N, C, H, W); ``F.interpolate`` also
                accepts 3D and 5D inputs. TODO confirm against callers.

        Returns:
            Tensor whose spatial dimensions are each multiplied by ``factor``.
        """
        return F.interpolate(x, scale_factor=self.factor, mode=self.mode)


class Downsample(nn.Module):
    """Downsamples the spatial dimensions of a tensor by an integer factor.

    Uses non-overlapping average pooling (kernel size == stride == factor).

    Args:
        factor: Size and stride of the pooling window. Defaults to 2, which
            matches the original fixed behaviour.
    """

    def __init__(self, factor: int = 2) -> None:
        super().__init__()
        self.factor = factor

    def extra_repr(self) -> str:
        """Shows the configuration when the module is printed."""
        return f"factor={self.factor}"

    def forward(self, x: Tensor) -> Tensor:
        """Applies downsampling.

        Args:
            x: Input batch. ``F.avg_pool2d`` requires a 3D (C, H, W) or 4D
                (N, C, H, W) tensor.

        Returns:
            Tensor whose last two dimensions are divided by ``factor``.
            Trailing rows or columns that do not fill a whole window are
            dropped, because ``avg_pool2d`` defaults to ``ceil_mode=False``.
        """
        return F.avg_pool2d(x, kernel_size=self.factor, stride=self.factor)