author     Gustaf Rydholm <gustaf.rydholm@gmail.com>   2024-04-15 21:49:51 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>   2024-04-15 21:49:51 +0200
commit     b3fbfd72a8f647161685b28d20b4b61519d8a643 (patch)
tree       a5cac4e15186396aae35231d6d6fe266691b0186 /text_recognizer/network/transformer/vit.py
parent     c7e5354ffa43eccfc4e411375ce2f531af7bbcff (diff)
Update transformer model
Diffstat (limited to 'text_recognizer/network/transformer/vit.py')
-rw-r--r--  text_recognizer/network/transformer/vit.py  64
1 file changed, 64 insertions(+), 0 deletions(-)
diff --git a/text_recognizer/network/transformer/vit.py b/text_recognizer/network/transformer/vit.py
new file mode 100644
index 0000000..3b600c3
--- /dev/null
+++ b/text_recognizer/network/transformer/vit.py
@@ -0,0 +1,64 @@
+import torch
+from einops import rearrange
+from einops.layers.torch import Rearrange
+from torch import Tensor, nn
+
+from .embedding.sincos import sincos_2d
+from .encoder import Encoder
+
+
+class PatchDropout(nn.Module):
+    def __init__(self, prob):
+        super().__init__()
+        assert 0 <= prob < 1.
+        self.prob = prob
+
+    def forward(self, x):
+        if not self.training or self.prob == 0.:
+            return x
+
+        b, n, _, device = *x.shape, x.device
+
+        batch_indices = torch.arange(b, device = device)
+        batch_indices = rearrange(batch_indices, '... -> ... 1')
+        num_patches_keep = max(1, int(n * (1 - self.prob)))
+        patch_indices_keep = torch.randn(b, n, device = device).topk(num_patches_keep, dim = -1).indices
+
+        return x[batch_indices, patch_indices_keep]
+
+
+class Vit(nn.Module):
+    def __init__(
+        self,
+        image_height: int,
+        image_width: int,
+        patch_height: int,
+        patch_width: int,
+        dim: int,
+        encoder: Encoder,
+        channels: int = 1,
+        patch_dropout: float = 0.0,
+    ) -> None:
+        super().__init__()
+        patch_dim = patch_height * patch_width * channels
+        self.to_patch_embedding = nn.Sequential(
+            Rearrange(
+                "b c (h ph) (w pw) -> b (h w) (ph pw c)",
+                ph=patch_height,
+                pw=patch_width,
+            ),
+            nn.LayerNorm(patch_dim),
+            nn.Linear(patch_dim, dim),
+            nn.LayerNorm(dim),
+        )
+        self.patch_embedding = sincos_2d(
+            h=image_height // patch_height, w=image_width // patch_width, dim=dim
+        )
+        self.encoder = encoder
+        self.patch_dropout = PatchDropout(patch_dropout)
+
+    def forward(self, images: Tensor) -> Tensor:
+        x = self.to_patch_embedding(images)
+        x = x + self.patch_embedding.to(images.device, dtype=images.dtype)
+        x = self.patch_dropout(x)
+        return self.encoder(x)
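
For reference, a minimal usage sketch of the new Vit module under the signature shown in the diff above. The Encoder constructor is not part of this diff, so its arguments below are hypothetical placeholders; the image and patch sizes are arbitrary, but each image dimension must be divisible by the corresponding patch dimension.

    # Minimal usage sketch; only the Vit signature comes from the diff above.
    import torch

    from text_recognizer.network.transformer.encoder import Encoder
    from text_recognizer.network.transformer.vit import Vit

    # Hypothetical Encoder arguments -- its real constructor is not shown in this commit.
    encoder = Encoder(dim=256, depth=4, heads=8)

    vit = Vit(
        image_height=56,
        image_width=1024,
        patch_height=28,   # 56 / 28 = 2 patch rows
        patch_width=32,    # 1024 / 32 = 32 patch columns
        dim=256,
        encoder=encoder,
        channels=1,
        patch_dropout=0.1,
    )

    images = torch.randn(2, 1, 56, 1024)  # (b, c, h, w)
    features = vit(images)  # (b, num_kept_patches, dim), assuming the encoder preserves sequence shape

Note that PatchDropout only drops patches while the module is in training mode; in eval mode the full 64-patch sequence is passed to the encoder unchanged.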