blob: dcc14aa7fc52c92900347b2264c2e05016d2a2bf (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
|
"""Convnext downsample module."""
from einops.layers.torch import Rearrange
from torch import Tensor, nn
class Downsample(nn.Module):
    """Downsamples feature maps by folding spatial patches into channels.

    Each non-overlapping ``factor x factor`` spatial patch is moved into the
    channel dimension (space-to-depth), after which a 1x1 convolution projects
    the resulting ``dim * factor**2`` channels down to ``dim_out``.

    Args:
        dim: Number of input channels.
        dim_out: Number of output channels.
        factor: Spatial downsampling factor (side length of each patch).
            Defaults to 2, i.e. 2x2 patches.

    Raises:
        ValueError: If ``factor`` is smaller than 1.
    """

    def __init__(self, dim: int, dim_out: int, *, factor: int = 2) -> None:
        super().__init__()
        if factor < 1:
            raise ValueError(f"factor must be >= 1, got {factor}")
        self.fn = nn.Sequential(
            # Same layout as einops "b c (h s1) (w s2) -> b (c s1 s2) h w":
            # output channel index is c * factor**2 + s1 * factor + s2.
            nn.PixelUnshuffle(factor),
            # Derive the channel multiplier from the factor instead of a
            # hard-coded 4, so patch size and conv width cannot drift apart.
            nn.Conv2d(dim * factor**2, dim_out, 1),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Applies the patch-folding downsample.

        Args:
            x: Tensor of shape ``(b, dim, h * factor, w * factor)``; both
                spatial sizes must be divisible by ``factor``.

        Returns:
            Tensor of shape ``(b, dim_out, h, w)``.
        """
        return self.fn(x)
|