diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-11 21:48:34 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-11 21:48:34 +0200 |
commit | 0ab820d3595e4f17d4f1f3c310e186692f65cc67 (patch) | |
tree | 21891ab98c10e64ef9261c69b2d494f42cda66f1 /text_recognizer/networks/image_transformer.py | |
parent | a548e421314908771ce9e413d9fa4e205943cceb (diff) |
Working on mapping
Diffstat (limited to 'text_recognizer/networks/image_transformer.py')
-rw-r--r-- | text_recognizer/networks/image_transformer.py | 23 |
1 files changed, 6 insertions, 17 deletions
diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py index 9ed67a4..daededa 100644 --- a/text_recognizer/networks/image_transformer.py +++ b/text_recognizer/networks/image_transformer.py @@ -10,16 +10,15 @@ TODO: Local attention for lower layer in attention. """ import importlib import math -from typing import Dict, List, Union, Sequence, Tuple, Type +from typing import Dict, Optional, Union, Sequence, Type from einops import rearrange from omegaconf import DictConfig, OmegaConf import torch from torch import nn from torch import Tensor -import torchvision -from text_recognizer.data.emnist import emnist_mapping +from text_recognizer.data.emnist import NUM_SPECIAL_TOKENS from text_recognizer.networks.transformer import ( Decoder, DecoderLayer, @@ -28,6 +27,8 @@ from text_recognizer.networks.transformer import ( target_padding_mask, ) +NUM_WORD_PIECES = 1000 + class ImageTransformer(nn.Module): def __init__( @@ -35,7 +36,7 @@ class ImageTransformer(nn.Module): input_shape: Sequence[int], output_shape: Sequence[int], encoder: Union[DictConfig, Dict], - mapping: str, + vocab_size: Optional[int] = None, num_decoder_layers: int = 4, hidden_dim: int = 256, num_heads: int = 4, @@ -43,14 +44,9 @@ class ImageTransformer(nn.Module): dropout_rate: float = 0.1, transformer_activation: str = "glu", ) -> None: - # Configure mapping - mapping, inverse_mapping = self._configure_mapping(mapping) - self.vocab_size = len(mapping) + self.vocab_size = NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size self.hidden_dim = hidden_dim self.max_output_length = output_shape[0] - self.start_index = inverse_mapping["<s>"] - self.end_index = inverse_mapping["<e>"] - self.pad_index = inverse_mapping["<p>"] # Image backbone self.encoder = self._configure_encoder(encoder) @@ -107,13 +103,6 @@ class ImageTransformer(nn.Module): encoder_class = getattr(network_module, encoder.type) return encoder_class(**encoder.args) - def _configure_mapping(self, mapping: str) -> Tuple[List[str], Dict[str, int]]: - """Configures mapping.""" - # TODO: Fix me!!! - if mapping == "emnist": - mapping, inverse_mapping, _ = emnist_mapping() - return mapping, inverse_mapping - def encode(self, image: Tensor) -> Tensor: """Extracts image features with backbone. |