diff options
Diffstat (limited to 'text_recognizer/networks/coat/patch_embedding.py')
-rw-r--r-- | text_recognizer/networks/coat/patch_embedding.py | 38 |
1 files changed, 38 insertions, 0 deletions
diff --git a/text_recognizer/networks/coat/patch_embedding.py b/text_recognizer/networks/coat/patch_embedding.py new file mode 100644 index 0000000..3b7b76a --- /dev/null +++ b/text_recognizer/networks/coat/patch_embedding.py @@ -0,0 +1,38 @@ +"""Patch embedding for images and feature maps.""" +from typing import Sequence, Tuple + +from einops import rearrange +from loguru import logger +from torch import nn +from torch import Tensor + + +class PatchEmbedding(nn.Module): + """Patch embedding of images.""" + + def __init__( + self, + image_shape: Sequence[int], + patch_size: int = 16, + in_channels: int = 1, + embedding_dim: int = 512, + ) -> None: + if image_shape[0] % patch_size == 0 and image_shape[1] % patch_size == 0: + logger.error( + f"Image shape {image_shape} not divisable by patch size {patch_size}" + ) + + self.patch_size = patch_size + self.embedding = nn.Conv2d( + in_channels, embedding_dim, kernel_size=patch_size, stride=patch_size + ) + self.norm = nn.LayerNorm(embedding_dim) + + def forward(self, x: Tensor) -> Tuple[Tensor, Tuple[int, int]]: + """Embeds image or feature maps with patch embedding.""" + _, _, h, w = x.shape + h_out, w_out = h // self.patch_size, w // self.patch_size + x = self.embedding(x) + x = rearrange(x, "b c h w -> b (h w) c") + x = self.norm(x) + return x, (h_out, w_out) |