diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-12-04 17:00:58 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-12-04 17:00:58 +0100 |
commit | f0e006105b68a6e86a8c50f1a365fed0f65da460 (patch) | |
tree | cc3bab047aacb13e217eeb90d05d073ee8b3e1a1 /text_recognizer/networks/transformer/vit.py | |
parent | 3de4312d1796b1ee56d6467d36773df29a831e45 (diff) |
Revert "Remove ViT"
This reverts commit 3de4312d1796b1ee56d6467d36773df29a831e45.
Diffstat (limited to 'text_recognizer/networks/transformer/vit.py')
-rw-r--r-- | text_recognizer/networks/transformer/vit.py | 46 |
1 files changed, 46 insertions, 0 deletions
diff --git a/text_recognizer/networks/transformer/vit.py b/text_recognizer/networks/transformer/vit.py new file mode 100644 index 0000000..ab331f8 --- /dev/null +++ b/text_recognizer/networks/transformer/vit.py @@ -0,0 +1,46 @@ +"""Vision Transformer.""" +from typing import Tuple, Type + +from einops.layers.torch import Rearrange +import torch +from torch import nn, Tensor + + +class ViT(nn.Module): + def __init__( + self, + image_size: Tuple[int, int], + patch_size: Tuple[int, int], + dim: int, + transformer: Type[nn.Module], + channels: int = 1, + ) -> None: + super().__init__() + img_height, img_width = image_size + patch_height, patch_width = patch_size + assert img_height % patch_height == 0 + assert img_width % patch_width == 0 + + num_patches = (img_height // patch_height) * (img_width // patch_width) + patch_dim = channels * patch_height * patch_width + + self.to_patch_embedding = nn.Sequential( + Rearrange( + "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", + p1=patch_height, + p2=patch_width, + c=channels, + ), + nn.Linear(patch_dim, dim), + ) + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim)) + self.transformer = transformer + self.norm = nn.LayerNorm(dim) + + def forward(self, img: Tensor) -> Tensor: + x = self.to_patch_embedding(img) + _, n, _ = x.shape + x += self.pos_embedding[:, :n] + x = self.transformer(x) + return x |