diff options
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/image_transformer.py | 23 |
1 files changed, 6 insertions, 17 deletions
diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py index 9ed67a4..daededa 100644 --- a/text_recognizer/networks/image_transformer.py +++ b/text_recognizer/networks/image_transformer.py @@ -10,16 +10,15 @@ TODO: Local attention for lower layer in attention. """ import importlib import math -from typing import Dict, List, Union, Sequence, Tuple, Type +from typing import Dict, Optional, Union, Sequence, Type from einops import rearrange from omegaconf import DictConfig, OmegaConf import torch from torch import nn from torch import Tensor -import torchvision -from text_recognizer.data.emnist import emnist_mapping +from text_recognizer.data.emnist import NUM_SPECIAL_TOKENS from text_recognizer.networks.transformer import ( Decoder, DecoderLayer, @@ -28,6 +27,8 @@ from text_recognizer.networks.transformer import ( target_padding_mask, ) +NUM_WORD_PIECES = 1000 + class ImageTransformer(nn.Module): def __init__( @@ -35,7 +36,7 @@ class ImageTransformer(nn.Module): input_shape: Sequence[int], output_shape: Sequence[int], encoder: Union[DictConfig, Dict], - mapping: str, + vocab_size: Optional[int] = None, num_decoder_layers: int = 4, hidden_dim: int = 256, num_heads: int = 4, @@ -43,14 +44,9 @@ class ImageTransformer(nn.Module): dropout_rate: float = 0.1, transformer_activation: str = "glu", ) -> None: - # Configure mapping - mapping, inverse_mapping = self._configure_mapping(mapping) - self.vocab_size = len(mapping) + self.vocab_size = NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size self.hidden_dim = hidden_dim self.max_output_length = output_shape[0] - self.start_index = inverse_mapping["<s>"] - self.end_index = inverse_mapping["<e>"] - self.pad_index = inverse_mapping["<p>"] # Image backbone self.encoder = self._configure_encoder(encoder) @@ -107,13 +103,6 @@ class ImageTransformer(nn.Module): encoder_class = getattr(network_module, encoder.type) return encoder_class(**encoder.args) - def _configure_mapping(self, mapping: str) -> Tuple[List[str], Dict[str, int]]: - """Configures mapping.""" - # TODO: Fix me!!! - if mapping == "emnist": - mapping, inverse_mapping, _ = emnist_mapping() - return mapping, inverse_mapping - def encode(self, image: Tensor) -> Tensor: """Extracts image features with backbone. |