summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/convnext/downsample.py
blob: a8a046633ff53f8d5c11af3b30dc902b36b9c36b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
"""Convnext downsample module."""
from typing import Tuple

from einops.layers.torch import Rearrange
from torch import Tensor, nn


class Downsample(nn.Module):
    """Downsamples feature maps by patches."""

    def __init__(self, dim: int, dim_out: int, factors: Tuple[int, int]) -> None:
        super().__init__()
        s1, s2 = factors
        self.fn = nn.Sequential(
            Rearrange("b c (h s1) (w s2) -> b (c s1 s2) h w", s1=s1, s2=s2),
            nn.Conv2d(dim * s1 * s2, dim_out, 1),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Applies patch function."""
        return self.fn(x)