diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-15 21:49:51 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-15 21:49:51 +0200 |
commit | b3fbfd72a8f647161685b28d20b4b61519d8a643 (patch) | |
tree | a5cac4e15186396aae35231d6d6fe266691b0186 /text_recognizer/network/vit.py | |
parent | c7e5354ffa43eccfc4e411375ce2f531af7bbcff (diff) |
Update transformer model
Diffstat (limited to 'text_recognizer/network/vit.py')
-rw-r--r-- | text_recognizer/network/vit.py | 39 |
1 files changed, 0 insertions, 39 deletions
```diff
diff --git a/text_recognizer/network/vit.py b/text_recognizer/network/vit.py
deleted file mode 100644
index a596792..0000000
--- a/text_recognizer/network/vit.py
+++ /dev/null
@@ -1,39 +0,0 @@
-from einops.layers.torch import Rearrange
-from torch import Tensor, nn
-
-from .transformer.embedding.sincos import sincos_2d
-from .transformer.encoder import Encoder
-
-
-class Vit(nn.Module):
-    def __init__(
-        self,
-        image_height: int,
-        image_width: int,
-        patch_height: int,
-        patch_width: int,
-        dim: int,
-        encoder: Encoder,
-        channels: int = 1,
-    ) -> None:
-        super().__init__()
-        patch_dim = patch_height * patch_width * channels
-        self.to_patch_embedding = nn.Sequential(
-            Rearrange(
-                "b c (h ph) (w pw) -> b (h w) (ph pw c)",
-                ph=patch_height,
-                pw=patch_width,
-            ),
-            nn.LayerNorm(patch_dim),
-            nn.Linear(patch_dim, dim),
-            nn.LayerNorm(dim),
-        )
-        self.patch_embedding = sincos_2d(
-            h=image_height // patch_height, w=image_width // patch_width, dim=dim
-        )
-        self.encoder = encoder
-
-    def forward(self, images: Tensor) -> Tensor:
-        x = self.to_patch_embedding(images)
-        x = x + self.patch_embedding.to(images.device, dtype=images.dtype)
-        return self.encoder(x)
```