from einops.layers.torch import Rearrange
from torch import Tensor, nn

from .transformer.embedding.sincos import sincos_2d
from .transformer.encoder import Encoder


class Vit(nn.Module):
    """Vision Transformer: patchify, embed, add 2D sin-cos positions, encode."""

    def __init__(
        self,
        image_height: int,
        image_width: int,
        patch_height: int,
        patch_width: int,
        dim: int,
        encoder: Encoder,
        channels: int = 1,
    ) -> None:
        super().__init__()
        patch_dim = patch_height * patch_width * channels

        # Split the image into non-overlapping patches and project each
        # flattened patch to the model dimension.
        self.to_patch_embedding = nn.Sequential(
            Rearrange(
                "b c (h ph) (w pw) -> b (h w) (ph pw c)",
                ph=patch_height,
                pw=patch_width,
            ),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # Fixed (non-learned) 2D sine-cosine positional embedding, one row per
        # patch position on the (h, w) patch grid.
        self.patch_embedding = sincos_2d(
            h=image_height // patch_height,
            w=image_width // patch_width,
            dim=dim,
        )

        self.encoder = encoder

    def forward(self, images: Tensor) -> Tensor:
        x = self.to_patch_embedding(images)
        # The positional embedding is a plain tensor rather than a registered
        # buffer, so move it to the input's device and dtype before adding.
        x = x + self.patch_embedding.to(images.device, dtype=images.dtype)
        return self.encoder(x)
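

# --- Usage sketch (illustrative only) ---
# A minimal smoke test, assuming a 224x224 single-channel input split into
# 16x16 patches. The real `Encoder` from .transformer.encoder is constructed
# elsewhere in this repo, so `nn.Identity()` is used here as a hypothetical
# stand-in purely to exercise the patchify + positional-embedding path.
if __name__ == "__main__":
    import torch

    vit = Vit(
        image_height=224,
        image_width=224,
        patch_height=16,
        patch_width=16,
        dim=256,
        encoder=nn.Identity(),  # stand-in; replace with a real Encoder instance
        channels=1,
    )
    images = torch.randn(2, 1, 224, 224)  # (batch, channels, height, width)
    tokens = vit(images)
    # (224 / 16) ** 2 = 196 patch tokens, each of dimension 256.
    print(tokens.shape)  # expected: torch.Size([2, 196, 256])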