Diffstat (limited to 'text_recognizer/networks/cnn_tranformer.py')
-rw-r--r-- | text_recognizer/networks/cnn_tranformer.py | 30 |
1 file changed, 17 insertions, 13 deletions
diff --git a/text_recognizer/networks/cnn_tranformer.py b/text_recognizer/networks/cnn_tranformer.py
index e030cb8..ce7ec43 100644
--- a/text_recognizer/networks/cnn_tranformer.py
+++ b/text_recognizer/networks/cnn_tranformer.py
@@ -7,6 +7,7 @@ import torch
 from torch import nn, Tensor
 
 from text_recognizer.data.mappings import AbstractMapping
+from text_recognizer.networks.encoders.efficientnet import EfficientNet
 from text_recognizer.networks.transformer.layers import Decoder
 from text_recognizer.networks.transformer.positional_encodings import (
     PositionalEncoding,
@@ -15,7 +16,7 @@ from text_recognizer.networks.transformer.positional_encodings import (
 )
 
 
 @attr.s
-class CnnTransformer(nn.Module):
+class Reader(nn.Module):
     def __attrs_pre_init__(self) -> None:
         super().__init__()
@@ -27,21 +28,20 @@ class CnnTransformer(nn.Module):
     num_classes: int = attr.ib()
     padding_idx: int = attr.ib()
     start_token: str = attr.ib()
-    start_index: int = attr.ib(init=False, default=None)
+    start_index: int = attr.ib(init=False)
     end_token: str = attr.ib()
-    end_index: int = attr.ib(init=False, default=None)
+    end_index: int = attr.ib(init=False)
     pad_token: str = attr.ib()
-    pad_index: int = attr.ib(init=False, default=None)
+    pad_index: int = attr.ib(init=False)
 
     # Modules.
-    encoder: Type[nn.Module] = attr.ib()
+    encoder: EfficientNet = attr.ib()
     decoder: Decoder = attr.ib()
-    embedding: nn.Embedding = attr.ib(init=False, default=None)
-    latent_encoder: nn.Sequential = attr.ib(init=False, default=None)
-    token_embedding: nn.Embedding = attr.ib(init=False, default=None)
-    token_pos_encoder: PositionalEncoding = attr.ib(init=False, default=None)
-    head: nn.Linear = attr.ib(init=False, default=None)
-    mapping: AbstractMapping = attr.ib(init=False, default=None)
+    latent_encoder: nn.Sequential = attr.ib(init=False)
+    token_embedding: nn.Embedding = attr.ib(init=False)
+    token_pos_encoder: PositionalEncoding = attr.ib(init=False)
+    head: nn.Linear = attr.ib(init=False)
+    mapping: Type[AbstractMapping] = attr.ib(init=False)
 
     def __attrs_post_init__(self) -> None:
         """Post init configuration."""
@@ -187,12 +187,16 @@ class CnnTransformer(nn.Module):
             output[:, i : i + 1] = tokens[-1:]
 
             # Early stopping of prediction loop if token is end or padding token.
-            if (output[:, i - 1] == self.end_index | output[: i - 1] == self.pad_index).all():
+            if (
+                output[:, i - 1] == self.end_index | output[: i - 1] == self.pad_index
+            ).all():
                 break
 
         # Set all tokens after end token to pad token.
        for i in range(1, self.max_output_len):
-            idx = (output[:, i -1] == self.end_index | output[:, i - 1] == self.pad_index)
+            idx = (
+                output[:, i - 1] == self.end_index | output[:, i - 1] == self.pad_index
+            )
             output[idx, i] = self.pad_index
 
         return output
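Note that the reflowed early-stopping condition carries over a pre-existing operator-precedence bug rather than fixing it: in Python, `|` binds tighter than `==`, so `output[:, i - 1] == self.end_index | output[:, i - 1] == self.pad_index` parses as the chained comparison `output[:, i - 1] == (self.end_index | output[:, i - 1]) == self.pad_index`, which fails at runtime for batches larger than one because a multi-element tensor has no single truth value. The first occurrence also drops the comma in `output[: i - 1]`, slicing rows instead of selecting column i - 1. A minimal standalone sketch of the intended check follows; the tensor values, `end_index`, and `pad_index` here are illustrative stand-ins, not taken from the repository:

    import torch

    # Hypothetical batch of predicted token ids (2 sequences, 3 steps).
    output = torch.tensor([[1, 3, 2], [1, 3, 3]])
    end_index, pad_index = 3, 2
    i = 2

    # Parenthesize each comparison so `|` combines two boolean masks
    # instead of being applied before `==` is evaluated.
    ended = (output[:, i - 1] == end_index) | (output[:, i - 1] == pad_index)
    print(ended.all())  # tensor(True): every sequence has ended, so the loop can break

The same parenthesization would apply to the `idx` mask in the pad-fill loop, which has the identical precedence problem.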