diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2023-08-25 23:19:14 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2023-08-25 23:19:14 +0200 |
commit | 49ca6ade1a19f7f9c702171537fe4be0dfcda66d (patch) | |
tree | 20062ed1910758481f3d5fff11159706c7b990c6 /text_recognizer/networks/conv_transformer.py | |
parent | 0421daf6bd97596703f426ba61c401599b538eeb (diff) |
Rename and add flash atten
Diffstat (limited to 'text_recognizer/networks/conv_transformer.py')
-rw-r--r-- | text_recognizer/networks/conv_transformer.py | 49 |
1 files changed, 0 insertions, 49 deletions
```diff
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
deleted file mode 100644
index d36162a..0000000
--- a/text_recognizer/networks/conv_transformer.py
+++ /dev/null
@@ -1,49 +0,0 @@
-"""Base network module."""
-from typing import Type
-
-from torch import Tensor, nn
-
-from text_recognizer.networks.transformer.decoder import Decoder
-
-
-class ConvTransformer(nn.Module):
-    """Base transformer network."""
-
-    def __init__(
-        self,
-        encoder: Type[nn.Module],
-        decoder: Decoder,
-    ) -> None:
-        super().__init__()
-        self.encoder = encoder
-        self.decoder = decoder
-
-    def encode(self, img: Tensor) -> Tensor:
-        """Encodes images to latent representation."""
-        return self.encoder(img)
-
-    def decode(self, tokens: Tensor, img_features: Tensor) -> Tensor:
-        """Decodes latent images embedding into characters."""
-        return self.decoder(tokens, img_features)
-
-    def forward(self, img: Tensor, tokens: Tensor) -> Tensor:
-        """Encodes images into token logtis.
-
-        Args:
-            img (Tensor): Input image(s).
-            tokens (Tensor): token embeddings.
-
-        Shapes:
-            - img: :math: `(B, 1, H, W)`
-            - tokens: :math: `(B, Sy)`
-            - logits: :math: `(B, Sy, C)`
-
-        where B is the batch size, H is the image height, W is the image
-        width, Sy the output length, and C is the number of classes.
-
-        Returns:
-            Tensor: Sequence of logits.
-        """
-        img_features = self.encode(img)
-        logits = self.decode(tokens, img_features)
-        return logits
```