diff options
Diffstat (limited to 'text_recognizer/networks/image_encoder.py')
-rw-r--r-- | text_recognizer/networks/image_encoder.py | 45 |
1 files changed, 0 insertions, 45 deletions
diff --git a/text_recognizer/networks/image_encoder.py b/text_recognizer/networks/image_encoder.py deleted file mode 100644 index ab60560..0000000 --- a/text_recognizer/networks/image_encoder.py +++ /dev/null @@ -1,45 +0,0 @@ -"""Encodes images to latent embeddings.""" -from typing import Tuple, Type - -from torch import Tensor, nn - -from text_recognizer.networks.transformer.embeddings.axial import ( - AxialPositionalEmbeddingImage, -) - - -class ImageEncoder(nn.Module): - """Encodes images to latent embeddings.""" - - def __init__( - self, - encoder: Type[nn.Module], - pixel_embedding: AxialPositionalEmbeddingImage, - ) -> None: - super().__init__() - self.encoder = encoder - self.pixel_embedding = pixel_embedding - - def forward(self, img: Tensor) -> Tensor: - """Encodes an image into a latent feature vector. - - Args: - img (Tensor): Image tensor. - - Shape: - - x: :math: `(B, C, H, W)` - - z: :math: `(B, Sx, D)` - - where Sx is the length of the flattened feature maps projected from - the encoder. D latent dimension for each pixel in the projected - feature maps. - - Returns: - Tensor: A Latent embedding of the image. - """ - z = self.encoder(img) - z = z + self.pixel_embedding(z) - z = z.flatten(start_dim=2) - # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] - z = z.permute(0, 2, 1) - return z |