diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-11 21:48:34 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-11 21:48:34 +0200 |
commit | 0ab820d3595e4f17d4f1f3c310e186692f65cc67 (patch) | |
tree | 21891ab98c10e64ef9261c69b2d494f42cda66f1 /text_recognizer | |
parent | a548e421314908771ce9e413d9fa4e205943cceb (diff) |
Working on mapping
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/data/mapping.py | 8 | ||||
-rw-r--r-- | text_recognizer/models/base.py | 1 | ||||
-rw-r--r-- | text_recognizer/models/transformer.py | 5 | ||||
-rw-r--r-- | text_recognizer/networks/image_transformer.py | 23 |
4 files changed, 17 insertions, 20 deletions
diff --git a/text_recognizer/data/mapping.py b/text_recognizer/data/mapping.py new file mode 100644 index 0000000..f0edf3f --- /dev/null +++ b/text_recognizer/data/mapping.py @@ -0,0 +1,8 @@ +"""Mapping to and from word pieces.""" +from pathlib import Path + + +class WordPieces: + + def __init__(self) -> None: + pass diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 0928e6c..c6d5d73 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -60,6 +60,7 @@ class LitBaseModel(pl.LightningModule): scheduler["scheduler"] = getattr( torch.optim.lr_scheduler, self._lr_scheduler.type )(optimizer, **args) + return scheduler def configure_optimizers(self) -> Tuple[List[type], List[Dict[str, Any]]]: diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index b23685b..7dc1352 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -1,5 +1,5 @@ """PyTorch Lightning model for base Transformers.""" -from typing import Dict, List, Optional, Union, Tuple +from typing import Dict, List, Optional, Union, Tuple, Type from omegaconf import DictConfig, OmegaConf import pytorch_lightning as pl @@ -19,7 +19,7 @@ class LitTransformerModel(LitBaseModel): def __init__( self, - network: Type[nn, Module], + network: Type[nn.Module], optimizer: Union[DictConfig, Dict], lr_scheduler: Union[DictConfig, Dict], criterion: Union[DictConfig, Dict], @@ -27,7 +27,6 @@ class LitTransformerModel(LitBaseModel): mapping: Optional[List[str]] = None, ) -> None: super().__init__(network, optimizer, lr_scheduler, criterion, monitor) - self.mapping, ignore_tokens = self.configure_mapping(mapping) self.val_cer = CharacterErrorRate(ignore_tokens) self.test_cer = CharacterErrorRate(ignore_tokens) diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py index 9ed67a4..daededa 100644 --- a/text_recognizer/networks/image_transformer.py +++ b/text_recognizer/networks/image_transformer.py @@ -10,16 +10,15 @@ TODO: Local attention for lower layer in attention. """ import importlib import math -from typing import Dict, List, Union, Sequence, Tuple, Type +from typing import Dict, Optional, Union, Sequence, Type from einops import rearrange from omegaconf import DictConfig, OmegaConf import torch from torch import nn from torch import Tensor -import torchvision -from text_recognizer.data.emnist import emnist_mapping +from text_recognizer.data.emnist import NUM_SPECIAL_TOKENS from text_recognizer.networks.transformer import ( Decoder, DecoderLayer, @@ -28,6 +27,8 @@ from text_recognizer.networks.transformer import ( target_padding_mask, ) +NUM_WORD_PIECES = 1000 + class ImageTransformer(nn.Module): def __init__( @@ -35,7 +36,7 @@ class ImageTransformer(nn.Module): input_shape: Sequence[int], output_shape: Sequence[int], encoder: Union[DictConfig, Dict], - mapping: str, + vocab_size: Optional[int] = None, num_decoder_layers: int = 4, hidden_dim: int = 256, num_heads: int = 4, @@ -43,14 +44,9 @@ class ImageTransformer(nn.Module): dropout_rate: float = 0.1, transformer_activation: str = "glu", ) -> None: - # Configure mapping - mapping, inverse_mapping = self._configure_mapping(mapping) - self.vocab_size = len(mapping) + self.vocab_size = NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size self.hidden_dim = hidden_dim self.max_output_length = output_shape[0] - self.start_index = inverse_mapping["<s>"] - self.end_index = inverse_mapping["<e>"] - self.pad_index = inverse_mapping["<p>"] # Image backbone self.encoder = self._configure_encoder(encoder) @@ -107,13 +103,6 @@ class ImageTransformer(nn.Module): encoder_class = getattr(network_module, encoder.type) return encoder_class(**encoder.args) - def _configure_mapping(self, mapping: str) -> Tuple[List[str], Dict[str, int]]: - """Configures mapping.""" - # TODO: Fix me!!! - if mapping == "emnist": - mapping, inverse_mapping, _ = emnist_mapping() - return mapping, inverse_mapping - def encode(self, image: Tensor) -> Tensor: """Extracts image features with backbone. |