summary refs log tree commit diff
path: root/text_recognizer/networks/conv_transformer.py
diff options
context:
space:
mode:
author Gustaf Rydholm <gustaf.rydholm@gmail.com> 2023-08-25 23:19:14 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com> 2023-08-25 23:19:14 +0200
commit 49ca6ade1a19f7f9c702171537fe4be0dfcda66d (patch)
tree 20062ed1910758481f3d5fff11159706c7b990c6 /text_recognizer/networks/conv_transformer.py
parent 0421daf6bd97596703f426ba61c401599b538eeb (diff)
Rename and add flash atten
Diffstat (limited to 'text_recognizer/networks/conv_transformer.py')
-rw-r--r-- text_recognizer/networks/conv_transformer.py 49
1 files changed, 0 insertions, 49 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
deleted file mode 100644
index d36162a..0000000
--- a/text_recognizer/networks/conv_transformer.py
+++ /dev/null
@@ -1,49 +0,0 @@
-"""Base network module."""
-from typing import Type
-
-from torch import Tensor, nn
-
-from text_recognizer.networks.transformer.decoder import Decoder
-
-
-class ConvTransformer(nn.Module):
- """Image-to-sequence network pairing an image encoder with a token decoder.
-
- The class defines no layers of its own: ``encode`` delegates to
- ``self.encoder``, ``decode`` delegates to ``self.decoder``, and ``forward``
- chains the two.
- """
-
- def __init__(
- self,
- encoder: Type[nn.Module],
- decoder: Decoder,
- ) -> None:
- """Stores the encoder and decoder sub-networks.
-
- Args:
- encoder: Network applied to the input image in ``encode``.
- NOTE(review): annotated as ``Type[nn.Module]`` (a class), but it
- is called directly on the image, so an instance is presumably
- expected; confirm against callers.
- decoder: Network called with ``(tokens, img_features)`` in ``decode``.
- """
- super().__init__()
- self.encoder = encoder
- self.decoder = decoder
-
- def encode(self, img: Tensor) -> Tensor:
- """Encodes images to a latent representation.
-
- Returns whatever ``self.encoder(img)`` returns, unmodified.
- """
- return self.encoder(img)
-
- def decode(self, tokens: Tensor, img_features: Tensor) -> Tensor:
- """Decodes latent image features into character logits.
-
- Returns whatever ``self.decoder(tokens, img_features)`` returns,
- unmodified.
- """
- return self.decoder(tokens, img_features)
-
- def forward(self, img: Tensor, tokens: Tensor) -> Tensor:
- """Encodes images and decodes them into token logits.
-
- Args:
- img (Tensor): Input image(s).
- tokens (Tensor): Tokens passed to the decoder. NOTE(review): the
- shape below is ``(B, Sy)``, so these are presumably token indices
- rather than embeddings; confirm against the decoder.
-
- Shapes:
- - img: :math: `(B, 1, H, W)`
- - tokens: :math: `(B, Sy)`
- - logits: :math: `(B, Sy, C)`
-
- where B is the batch size, H is the image height, W is the image
- width, Sy the output length, and C is the number of classes.
-
- Returns:
- Tensor: Sequence of logits.
- """
- img_features = self.encode(img)
- logits = self.decode(tokens, img_features)
- return logits