Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--  text_recognizer/networks/transformer/vit.py  46
1 file changed, 0 insertions, 46 deletions
diff --git a/text_recognizer/networks/transformer/vit.py b/text_recognizer/networks/transformer/vit.py
deleted file mode 100644
index ab331f8..0000000
--- a/text_recognizer/networks/transformer/vit.py
+++ /dev/null
@@ -1,46 +0,0 @@
-"""Vision Transformer."""
-from typing import Tuple
-
-from einops.layers.torch import Rearrange
-import torch
-from torch import nn, Tensor
-
-
-class ViT(nn.Module):
-    def __init__(
-        self,
-        image_size: Tuple[int, int],
-        patch_size: Tuple[int, int],
-        dim: int,
-        transformer: nn.Module,
-        channels: int = 1,
-    ) -> None:
-        super().__init__()
-        img_height, img_width = image_size
-        patch_height, patch_width = patch_size
-        assert img_height % patch_height == 0, "Image height must be divisible by patch height."
-        assert img_width % patch_width == 0, "Image width must be divisible by patch width."
-
-        num_patches = (img_height // patch_height) * (img_width // patch_width)
-        patch_dim = channels * patch_height * patch_width
-
-        self.to_patch_embedding = nn.Sequential(  # flatten patches, then project to dim
-            Rearrange(
-                "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
-                p1=patch_height,
-                p2=patch_width,
-                c=channels,
-            ),
-            nn.Linear(patch_dim, dim),
-        )
-
-        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))  # learned positional embedding
-        self.transformer = transformer
-        self.norm = nn.LayerNorm(dim)
-
-    def forward(self, img: Tensor) -> Tensor:
-        x = self.to_patch_embedding(img)  # (b, c, h, w) -> (b, num_patches, dim)
-        _, n, _ = x.shape
-        x += self.pos_embedding[:, :n]
-        x = self.transformer(x)
-        return self.norm(x)
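
For reference, a minimal usage sketch of the removed module, assuming the ViT class above is importable. The transformer argument is an instance (not a class) that maps (batch, num_patches, dim) tensors to the same shape; the encoder, image size, and dimensions below are illustrative assumptions, with nn.TransformerEncoder standing in for the project's own transformer.

import torch
from torch import nn

# Hypothetical stand-in for the repo's transformer; any module mapping
# (batch, num_patches, dim) -> (batch, num_patches, dim) would work here.
layer = nn.TransformerEncoderLayer(d_model=256, nhead=8, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=4)

vit = ViT(
    image_size=(576, 640),  # example size; must be divisible by patch_size
    patch_size=(16, 16),
    dim=256,
    transformer=encoder,
    channels=1,
)

out = vit(torch.randn(1, 1, 576, 640))
print(out.shape)  # torch.Size([1, 1440, 256]); 1440 = (576 // 16) * (640 // 16)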