diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-27 22:16:21 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-27 22:16:21 +0200 |
commit | 4a6550ddef7d1f1971737bc22715db6381441f79 (patch) | |
tree | 33162ce381adf3b1b835df718aa3691ae87b8f4d /text_recognizer/networks | |
parent | b3d3b7ddc0796e98d78561bc5ca22728dc0372b0 (diff) |
Remove ViT
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/transformer/vit.py | 46 |
1 files changed, 0 insertions, 46 deletions
diff --git a/text_recognizer/networks/transformer/vit.py b/text_recognizer/networks/transformer/vit.py deleted file mode 100644 index ab331f8..0000000 --- a/text_recognizer/networks/transformer/vit.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Vision Transformer.""" -from typing import Tuple, Type - -from einops.layers.torch import Rearrange -import torch -from torch import nn, Tensor - - -class ViT(nn.Module): - def __init__( - self, - image_size: Tuple[int, int], - patch_size: Tuple[int, int], - dim: int, - transformer: Type[nn.Module], - channels: int = 1, - ) -> None: - super().__init__() - img_height, img_width = image_size - patch_height, patch_width = patch_size - assert img_height % patch_height == 0 - assert img_width % patch_width == 0 - - num_patches = (img_height // patch_height) * (img_width // patch_width) - patch_dim = channels * patch_height * patch_width - - self.to_patch_embedding = nn.Sequential( - Rearrange( - "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", - p1=patch_height, - p2=patch_width, - c=channels, - ), - nn.Linear(patch_dim, dim), - ) - - self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim)) - self.transformer = transformer - self.norm = nn.LayerNorm(dim) - - def forward(self, img: Tensor) -> Tensor: - x = self.to_patch_embedding(img) - _, n, _ = x.shape - x += self.pos_embedding[:, :n] - x = self.transformer(x) - return x |