Diffstat (limited to 'text_recognizer/networks/conv_transformer.py')
 text_recognizer/networks/conv_transformer.py | 42 +++++++++++++++++++-----------------------
 1 file changed, 19 insertions(+), 23 deletions(-)
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index 09cc654..f3ba49d 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -2,7 +2,6 @@
 import math
 from typing import Tuple
 
-import attr
 from torch import nn, Tensor
 
 from text_recognizer.networks.encoders.efficientnet import EfficientNet
@@ -13,32 +12,28 @@ from text_recognizer.networks.transformer.positional_encodings import (
 )
 
 
-@attr.s(eq=False)
 class ConvTransformer(nn.Module):
     """Convolutional encoder and transformer decoder network."""
 
-    def __attrs_pre_init__(self) -> None:
+    def __init__(
+        self,
+        input_dims: Tuple[int, int, int],
+        hidden_dim: int,
+        dropout_rate: float,
+        num_classes: int,
+        pad_index: Tensor,
+        encoder: EfficientNet,
+        decoder: Decoder,
+    ) -> None:
         super().__init__()
+        self.input_dims = input_dims
+        self.hidden_dim = hidden_dim
+        self.dropout_rate = dropout_rate
+        self.num_classes = num_classes
+        self.pad_index = pad_index
+        self.encoder = encoder
+        self.decoder = decoder
 
-    # Parameters and placeholders,
-    input_dims: Tuple[int, int, int] = attr.ib()
-    hidden_dim: int = attr.ib()
-    dropout_rate: float = attr.ib()
-    max_output_len: int = attr.ib()
-    num_classes: int = attr.ib()
-    pad_index: Tensor = attr.ib()
-
-    # Modules.
-    encoder: EfficientNet = attr.ib()
-    decoder: Decoder = attr.ib()
-
-    latent_encoder: nn.Sequential = attr.ib(init=False)
-    token_embedding: nn.Embedding = attr.ib(init=False)
-    token_pos_encoder: PositionalEncoding = attr.ib(init=False)
-    head: nn.Linear = attr.ib(init=False)
-
-    def __attrs_post_init__(self) -> None:
-        """Post init configuration."""
         # Latent projector for down sampling number of filters and 2d
         # positional encoding.
         self.latent_encoder = nn.Sequential(
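
Note on the hunk above: the removed __attrs_pre_init__ hook existed only so that nn.Module.__init__ would run before attrs assigned any fields; with a plain __init__, calling super().__init__() first gives the same guarantee. nn.Module requires its own __init__ to run before any submodule is attached to self. A minimal sketch of that ordering constraint (the Toy class and its dim argument are illustrative, not from this repo):

from torch import nn

class Toy(nn.Module):
    def __init__(self, dim: int) -> None:
        # nn.Module.__init__ must run first: it creates the registries
        # (_parameters, _modules, _buffers) that attribute assignment uses.
        super().__init__()
        self.proj = nn.Linear(dim, dim)  # routed through nn.Module.__setattr__

Skipping super().__init__() makes the self.proj assignment raise
AttributeError: cannot assign module before Module.__init__() call.
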
@@ -126,7 +121,8 @@ class ConvTransformer(nn.Module):
         context = self.token_embedding(context) * math.sqrt(self.hidden_dim)
         context = self.token_pos_encoder(context)
         out = self.decoder(x=context, context=z, mask=context_mask)
-        logits = self.head(out)
+        logits = self.head(out)  # [B, Sy, T]
+        logits = logits.permute(0, 2, 1)  # [B, T, Sy]
         return logits
 
     def forward(self, x: Tensor, context: Tensor) -> Tensor:
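
The added permute moves the class dimension (T, i.e. num_classes) to index 1, which is the layout PyTorch's cross-entropy loss expects for sequence outputs: input (N, C, d1) against targets (N, d1). A quick sketch with made-up sizes (B, Sy, and the vocabulary size below are hypothetical, not taken from this repo):

import torch
import torch.nn.functional as F

B, Sy, num_classes = 4, 32, 58  # hypothetical batch, target length, vocab size
logits = torch.randn(B, num_classes, Sy)  # network output after the permute
targets = torch.randint(0, num_classes, (B, Sy))  # one token index per position

loss = F.cross_entropy(logits, targets)  # requires the class dim at index 1

With the pre-permute shape [B, Sy, T], the same call would fail with a shape mismatch, since cross_entropy always treats dimension 1 as the class dimension.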