summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/convnext/downsample.py
blob: c28eccaf05fc44605dbd988d77fb3ec8a6623610 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from typing import Tuple

from einops.layers.torch import Rearrange
from torch import Tensor, nn


class Downsample(nn.Module):
    """Reduce spatial resolution by moving pixel patches into channels.

    Every ``s1 x s2`` patch of the input is folded into the channel axis,
    after which a convolution with kernel size 1 maps the enlarged channel
    axis to ``dim_out`` channels.
    """

    def __init__(self, dim: int, dim_out: int, factors: Tuple[int, int]) -> None:
        """Build the patch-folding layer and the channel projection.

        Args:
            dim: Number of channels of the incoming tensor.
            dim_out: Number of channels produced by the projection.
            factors: Pair ``(s1, s2)`` of reduction factors applied to the
                height and width axes respectively.
        """
        super().__init__()
        height_factor, width_factor = factors

        # The rearrange pattern splits height into (h, s1) and width into
        # (w, s2), then merges s1 and s2 into the channel axis.
        fold_patches = Rearrange(
            "b c (h s1) (w s2) -> b (c s1 s2) h w",
            s1=height_factor,
            s2=width_factor,
        )
        # After folding, the channel count is dim * s1 * s2; the convolution
        # with kernel size 1 projects it to dim_out.
        project = nn.Conv2d(
            in_channels=dim * height_factor * width_factor,
            out_channels=dim_out,
            kernel_size=1,
        )
        self.fn = nn.Sequential(fold_patches, project)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the patch folding followed by the channel projection.

        Args:
            x: Tensor laid out as ``(b, c, h * s1, w * s2)`` per the
                rearrange pattern used at construction.

        Returns:
            The result of running ``x`` through ``self.fn``.
        """
        downsampled = self.fn(x)
        return downsampled