From b3fbfd72a8f647161685b28d20b4b61519d8a643 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Mon, 15 Apr 2024 21:49:51 +0200
Subject: Update transformer model

---
 text_recognizer/network/transformer/vit.py | 64 ++++++++++++++++++++++++++++++
 1 file changed, 64 insertions(+)
 create mode 100644 text_recognizer/network/transformer/vit.py

diff --git a/text_recognizer/network/transformer/vit.py b/text_recognizer/network/transformer/vit.py
new file mode 100644
index 0000000..3b600c3
--- /dev/null
+++ b/text_recognizer/network/transformer/vit.py
@@ -0,0 +1,64 @@
+import torch
+from einops import rearrange
+from einops.layers.torch import Rearrange
+from torch import Tensor, nn
+
+from .embedding.sincos import sincos_2d
+from .encoder import Encoder
+
+
+class PatchDropout(nn.Module):
+    def __init__(self, prob: float) -> None:
+        super().__init__()
+        assert 0 <= prob < 1.0
+        self.prob = prob
+
+    def forward(self, x: Tensor) -> Tensor:
+        if not self.training or self.prob == 0.0:
+            return x
+
+        b, n, _, device = *x.shape, x.device
+
+        batch_indices = torch.arange(b, device=device)
+        batch_indices = rearrange(batch_indices, "... -> ... 1")
+        num_patches_keep = max(1, int(n * (1 - self.prob)))
+        patch_indices_keep = torch.randn(b, n, device=device).topk(num_patches_keep, dim=-1).indices
+
+        return x[batch_indices, patch_indices_keep]
+
+
+class Vit(nn.Module):
+    def __init__(
+        self,
+        image_height: int,
+        image_width: int,
+        patch_height: int,
+        patch_width: int,
+        dim: int,
+        encoder: Encoder,
+        channels: int = 1,
+        patch_dropout: float = 0.0,
+    ) -> None:
+        super().__init__()
+        patch_dim = patch_height * patch_width * channels
+        self.to_patch_embedding = nn.Sequential(
+            Rearrange(
+                "b c (h ph) (w pw) -> b (h w) (ph pw c)",
+                ph=patch_height,
+                pw=patch_width,
+            ),
+            nn.LayerNorm(patch_dim),
+            nn.Linear(patch_dim, dim),
+            nn.LayerNorm(dim),
+        )
+        self.patch_embedding = sincos_2d(
+            h=image_height // patch_height, w=image_width // patch_width, dim=dim
+        )
+        self.encoder = encoder
+        self.patch_dropout = PatchDropout(patch_dropout)
+
+    def forward(self, images: Tensor) -> Tensor:
+        x = self.to_patch_embedding(images)
+        x = x + self.patch_embedding.to(images.device, dtype=images.dtype)
+        x = self.patch_dropout(x)
+        return self.encoder(x)
--
cgit v1.2.3-70-g09d2
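
For reference, a minimal usage sketch (not part of the commit). It assumes the surrounding text_recognizer package is importable so that the sibling .embedding.sincos and .encoder modules resolve, uses nn.Identity() purely as a stand-in for the repo's Encoder (whose constructor is not shown in this patch), and the image and patch sizes below are made up for illustration.

    import torch
    from torch import nn

    from text_recognizer.network.transformer.vit import Vit

    vit = Vit(
        image_height=32,        # assumed sizes, not taken from the repo config
        image_width=128,
        patch_height=4,
        patch_width=16,
        dim=256,
        encoder=nn.Identity(),  # stand-in for .encoder.Encoder; type hint is not enforced
        channels=1,
        patch_dropout=0.1,
    )
    vit.train()  # PatchDropout is only active in training mode

    images = torch.randn(2, 1, 32, 128)  # (batch, channels, height, width)
    tokens = vit(images)

    # Each image is cut into (32 / 4) * (128 / 16) = 64 patches, each flattened to
    # 4 * 16 * 1 = 64 values and projected to dim=256. With patch_dropout=0.1 the
    # dropout keeps max(1, int(64 * 0.9)) = 57 randomly chosen patches per image.
    print(tokens.shape)  # torch.Size([2, 57, 256])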