summaryrefslogtreecommitdiff
path: root/text_recognizer/network/convnext/downsample.py
blob: dcc14aa7fc52c92900347b2264c2e05016d2a2bf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
"""Convnext downsample module."""
from einops.layers.torch import Rearrange
from torch import Tensor, nn


class Downsample(nn.Module):
    """Downsamples feature maps by folding non-overlapping patches into channels.

    Each ``factor x factor`` spatial patch is moved into the channel dimension
    (space-to-depth), then a 1x1 convolution projects the resulting
    ``dim * factor**2`` channels to ``dim_out``. For an input of shape
    ``(b, dim, H, W)`` the output has shape
    ``(b, dim_out, H // factor, W // factor)``.

    Args:
        dim: Number of input channels.
        dim_out: Number of output channels.
        factor: Spatial downsampling factor applied to both height and width.
            Defaults to 2. Input ``H`` and ``W`` must be divisible by it.

    Raises:
        ValueError: If ``factor`` is smaller than 1.
    """

    def __init__(self, dim: int, dim_out: int, factor: int = 2) -> None:
        super().__init__()
        if factor < 1:
            raise ValueError(f"factor must be >= 1, got {factor}")
        # PixelUnshuffle is the native equivalent of the einops pattern
        # "b c (h s1) (w s2) -> b (c s1 s2) h w": output channel index is
        # c * factor**2 + s1 * factor + s2. Keeping it at index 0 of the
        # Sequential leaves the conv's state_dict keys ("fn.1.*") unchanged.
        self.fn = nn.Sequential(
            nn.PixelUnshuffle(factor),
            nn.Conv2d(dim * factor**2, dim_out, 1),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Applies patch folding followed by the 1x1 projection.

        Args:
            x: Feature map of shape ``(b, dim, H, W)``; ``H`` and ``W`` must be
                divisible by ``factor``.

        Returns:
            Tensor of shape ``(b, dim_out, H // factor, W // factor)``.

        Raises:
            RuntimeError: If ``H`` or ``W`` is not divisible by ``factor``.
        """
        return self.fn(x)