diff options
Diffstat (limited to 'text_recognizer/networks/coat/patch_embedding.py')
-rw-r--r-- | text_recognizer/networks/coat/patch_embedding.py | 38 |
1 files changed, 0 insertions, 38 deletions
diff --git a/text_recognizer/networks/coat/patch_embedding.py b/text_recognizer/networks/coat/patch_embedding.py deleted file mode 100644 index 3b7b76a..0000000 --- a/text_recognizer/networks/coat/patch_embedding.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Patch embedding for images and feature maps.""" -from typing import Sequence, Tuple - -from einops import rearrange -from loguru import logger -from torch import nn -from torch import Tensor - - -class PatchEmbedding(nn.Module): - """Patch embedding of images.""" - - def __init__( - self, - image_shape: Sequence[int], - patch_size: int = 16, - in_channels: int = 1, - embedding_dim: int = 512, - ) -> None: - if image_shape[0] % patch_size == 0 and image_shape[1] % patch_size == 0: - logger.error( - f"Image shape {image_shape} not divisable by patch size {patch_size}" - ) - - self.patch_size = patch_size - self.embedding = nn.Conv2d( - in_channels, embedding_dim, kernel_size=patch_size, stride=patch_size - ) - self.norm = nn.LayerNorm(embedding_dim) - - def forward(self, x: Tensor) -> Tuple[Tensor, Tuple[int, int]]: - """Embeds image or feature maps with patch embedding.""" - _, _, h, w = x.shape - h_out, w_out = h // self.patch_size, w // self.patch_size - x = self.embedding(x) - x = rearrange(x, "b c h w -> b (h w) c") - x = self.norm(x) - return x, (h_out, w_out) |