summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/coat/patch_embedding.py
blob: 3b7b76a05913fb5c630218e3bdfc2aa38715b2c8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
"""Patch embedding for images and feature maps."""
from typing import Sequence, Tuple

from einops import rearrange
from loguru import logger
from torch import nn
from torch import Tensor


class PatchEmbedding(nn.Module):
    """Patch embedding of images or feature maps.

    Splits the input into non-overlapping ``patch_size`` x ``patch_size``
    patches with a strided convolution, projects each patch to
    ``embedding_dim`` channels, flattens the patch grid into a token
    sequence, and applies layer normalization over the embedding dimension.
    """

    def __init__(
        self,
        image_shape: Sequence[int],
        patch_size: int = 16,
        in_channels: int = 1,
        embedding_dim: int = 512,
    ) -> None:
        """Builds the patch projection and normalization layers.

        Args:
            image_shape: Expected input shape. Only the first two entries are
                checked for divisibility by ``patch_size`` (presumably
                ``(height, width)`` — confirm against callers).
            patch_size: Side length of each square patch. It is also the
                convolution stride, so patches do not overlap.
            in_channels: Number of channels in the input tensor.
            embedding_dim: Number of output channels (embedding size) per
                patch.
        """
        # nn.Module must be initialized before any submodule is assigned.
        super().__init__()
        if image_shape[0] % patch_size != 0 or image_shape[1] % patch_size != 0:
            # Not fatal: the strided conv floors the grid size, so trailing
            # rows/columns that do not fill a whole patch are ignored.
            logger.error(
                f"Image shape {image_shape} not divisible by patch size {patch_size}"
            )

        self.patch_size = patch_size
        # kernel_size == stride gives one output position per patch.
        self.embedding = nn.Conv2d(
            in_channels, embedding_dim, kernel_size=patch_size, stride=patch_size
        )
        self.norm = nn.LayerNorm(embedding_dim)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tuple[int, int]]:
        """Embeds an image or feature map as a sequence of patch tokens.

        Args:
            x: Input tensor of shape ``(batch, in_channels, height, width)``.

        Returns:
            A tuple of two items:

            - The normalized patch embeddings, of shape
              ``(batch, h_out * w_out, embedding_dim)``.
            - The patch-grid size ``(h_out, w_out)``.
        """
        x = self.embedding(x)
        # Take the grid size from the conv output so it always matches the
        # number of tokens produced, instead of recomputing it by hand.
        _, _, h_out, w_out = x.shape
        x = rearrange(x, "b c h w -> b (h w) c")
        x = self.norm(x)
        return x, (h_out, w_out)