1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
|
"""Patch embedding for images and feature maps."""
from typing import Sequence, Tuple
from einops import rearrange
from loguru import logger
from torch import nn
from torch import Tensor
class PatchEmbedding(nn.Module):
    """Patch embedding of images.

    Splits a 4D input into non-overlapping ``patch_size`` x ``patch_size``
    patches with a strided convolution, projects every patch to
    ``embedding_dim`` channels, flattens the patch grid into a sequence and
    applies layer normalization over the embedding dimension.
    """

    def __init__(
        self,
        image_shape: Sequence[int],
        patch_size: int = 16,
        in_channels: int = 1,
        embedding_dim: int = 512,
    ) -> None:
        """Build the patch projection and the normalization layer.

        Args:
            image_shape: Expected input size; only the first two entries are
                inspected (presumably height and width -- confirm with callers).
            patch_size: Side length of each square patch; used as both kernel
                size and stride of the projection.
            in_channels: Number of channels of the input.
            embedding_dim: Number of channels produced for each patch.
        """
        # nn.Module must be initialized before any submodule is assigned,
        # otherwise registering `self.embedding` raises AttributeError.
        super().__init__()

        # Report (without raising) when the patches cannot tile the image
        # exactly; the strided convolution then ignores the trailing remainder.
        if image_shape[0] % patch_size != 0 or image_shape[1] % patch_size != 0:
            logger.error(
                f"Image shape {image_shape} not divisable by patch size {patch_size}"
            )

        self.patch_size = patch_size
        # kernel_size == stride, so every patch is projected independently.
        self.embedding = nn.Conv2d(
            in_channels, embedding_dim, kernel_size=patch_size, stride=patch_size
        )
        self.norm = nn.LayerNorm(embedding_dim)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tuple[int, int]]:
        """Embeds image or feature maps with patch embedding.

        Args:
            x: 4D input laid out as (batch, channels, height, width).

        Returns:
            A tuple ``(tokens, (h_out, w_out))`` where ``tokens`` has layout
            (batch, h_out * w_out, embedding channels) and ``(h_out, w_out)``
            is the size of the patch grid.
        """
        x = self.embedding(x)
        # Read the grid size from the projection output itself; with
        # kernel_size == stride this equals height // patch_size and
        # width // patch_size of the input.
        _, _, h_out, w_out = x.shape
        x = rearrange(x, "b c h w -> b (h w) c")
        x = self.norm(x)
        return x, (h_out, w_out)
|