From 1ca8b0b9e0613c1e02f6a5d8b49e20c4d6916412 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Thu, 22 Apr 2021 08:15:58 +0200 Subject: Fixed training script, able to train vqvae --- text_recognizer/data/__init__.py | 3 + text_recognizer/data/emnist_lines.py | 2 +- text_recognizer/data/iam_extended_paragraphs.py | 15 +- text_recognizer/data/iam_paragraphs.py | 23 +- text_recognizer/data/iam_preprocessor.py | 1 + text_recognizer/data/iam_synthetic_paragraphs.py | 7 +- text_recognizer/data/mappings.py | 16 +- text_recognizer/data/transforms.py | 14 +- text_recognizer/models/__init__.py | 3 + text_recognizer/models/base.py | 9 + text_recognizer/models/vqvae.py | 70 ++++++ text_recognizer/networks/__init__.py | 2 +- text_recognizer/networks/cnn_transformer.py | 257 +++++++++++----------- text_recognizer/networks/image_transformer.py | 165 -------------- text_recognizer/networks/residual_network.py | 6 +- text_recognizer/networks/transducer/transducer.py | 7 +- text_recognizer/networks/vqvae/decoder.py | 20 +- text_recognizer/networks/vqvae/encoder.py | 30 ++- text_recognizer/networks/vqvae/vqvae.py | 5 +- 19 files changed, 318 insertions(+), 337 deletions(-) create mode 100644 text_recognizer/models/vqvae.py delete mode 100644 text_recognizer/networks/image_transformer.py (limited to 'text_recognizer') diff --git a/text_recognizer/data/__init__.py b/text_recognizer/data/__init__.py index 9a42fa9..3599a8b 100644 --- a/text_recognizer/data/__init__.py +++ b/text_recognizer/data/__init__.py @@ -2,3 +2,6 @@ from .base_dataset import BaseDataset, convert_strings_to_labels, split_dataset from .base_data_module import BaseDataModule, load_and_print_info from .download_utils import download_dataset +from .iam_paragraphs import IAMParagraphs +from .iam_synthetic_paragraphs import IAMSyntheticParagraphs +from .iam_extended_paragraphs import IAMExtendedParagraphs diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py index 72665d0..9650198 
100644 --- a/text_recognizer/data/emnist_lines.py +++ b/text_recognizer/data/emnist_lines.py @@ -57,8 +57,8 @@ class EMNISTLines(BaseDataModule): self.num_test = num_test self.emnist = EMNIST() - # TODO: fix mapping self.mapping = self.emnist.mapping + max_width = ( int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap)) + IMAGE_X_PADDING diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py index d2529b4..2380660 100644 --- a/text_recognizer/data/iam_extended_paragraphs.py +++ b/text_recognizer/data/iam_extended_paragraphs.py @@ -10,18 +10,27 @@ from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs class IAMExtendedParagraphs(BaseDataModule): def __init__( self, - batch_size: int = 128, + batch_size: int = 16, num_workers: int = 0, train_fraction: float = 0.8, augment: bool = True, + word_pieces: bool = False, ) -> None: super().__init__(batch_size, num_workers) self.iam_paragraphs = IAMParagraphs( - batch_size, num_workers, train_fraction, augment, + batch_size, + num_workers, + train_fraction, + augment, + word_pieces, ) self.iam_synthetic_paragraphs = IAMSyntheticParagraphs( - batch_size, num_workers, train_fraction, augment, + batch_size, + num_workers, + train_fraction, + augment, + word_pieces, ) self.dims = self.iam_paragraphs.dims diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index f588587..62c44f9 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -5,8 +5,7 @@ from typing import Dict, List, Optional, Sequence, Tuple from loguru import logger import numpy as np -from PIL import Image, ImageFile, ImageOps -import torch +from PIL import Image, ImageOps import torchvision.transforms as transforms from torchvision.transforms.functional import InterpolationMode from tqdm import tqdm @@ -19,6 +18,7 @@ from text_recognizer.data.base_dataset import ( from 
text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.emnist import emnist_mapping from text_recognizer.data.iam import IAM +from text_recognizer.data.transforms import WordPiece PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_paragraphs" @@ -37,15 +37,15 @@ class IAMParagraphs(BaseDataModule): def __init__( self, - batch_size: int = 128, + batch_size: int = 16, num_workers: int = 0, train_fraction: float = 0.8, augment: bool = True, + word_pieces: bool = False, ) -> None: super().__init__(batch_size, num_workers) - # TODO: pass in transform and target transform - # TODO: pass in mapping self.augment = augment + self.word_pieces = word_pieces self.mapping, self.inverse_mapping, _ = emnist_mapping( extra_symbols=[NEW_LINE_TOKEN] ) @@ -101,6 +101,7 @@ class IAMParagraphs(BaseDataModule): data, targets, transform=get_transform(image_shape=self.dims[1:], augment=augment), + target_transform=get_target_transform(self.word_pieces) ) logger.info(f"Loading IAM paragraph regions and lines for {stage}...") @@ -161,7 +162,10 @@ def get_dataset_properties() -> Dict: "min": min(_get_property_values("num_lines")), "max": max(_get_property_values("num_lines")), }, - "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0),}, + "crop_shape": { + "min": crop_shapes.min(axis=0), + "max": crop_shapes.max(axis=0), + }, "aspect_ratio": { "min": aspect_ratio.min(axis=0), "max": aspect_ratio.max(axis=0), @@ -282,7 +286,9 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Com ), transforms.ColorJitter(brightness=(0.8, 1.6)), transforms.RandomAffine( - degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR, + degrees=1, + shear=(-10, 10), + interpolation=InterpolationMode.BILINEAR, ), ] else: @@ -290,6 +296,9 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Com transforms_list.append(transforms.ToTensor()) return 
transforms.Compose(transforms_list) +def get_target_transform(word_pieces: bool) -> Optional[transforms.Compose]: + """Transform emnist characters to word pieces.""" + return transforms.Compose([WordPiece()]) if word_pieces else None def _labels_filename(split: str) -> Path: """Return filename of processed labels.""" diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py index 60f8a9f..b5f72da 100644 --- a/text_recognizer/data/iam_preprocessor.py +++ b/text_recognizer/data/iam_preprocessor.py @@ -89,6 +89,7 @@ class Preprocessor: self.lexicon = None if self.special_tokens is not None: + self.special_tokens += ("#", "*") self.tokens += self.special_tokens self.graphemes += self.special_tokens diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py index 9f1bd12..4ccc5c2 100644 --- a/text_recognizer/data/iam_synthetic_paragraphs.py +++ b/text_recognizer/data/iam_synthetic_paragraphs.py @@ -18,6 +18,7 @@ from text_recognizer.data.base_data_module import BaseDataModule, load_and_print from text_recognizer.data.iam_paragraphs import ( get_dataset_properties, get_transform, + get_target_transform, NEW_LINE_TOKEN, IAMParagraphs, IMAGE_SCALE_FACTOR, @@ -41,12 +42,13 @@ class IAMSyntheticParagraphs(IAMParagraphs): def __init__( self, - batch_size: int = 128, + batch_size: int = 16, num_workers: int = 0, train_fraction: float = 0.8, augment: bool = True, + word_pieces: bool = False, ) -> None: - super().__init__(batch_size, num_workers, train_fraction, augment) + super().__init__(batch_size, num_workers, train_fraction, augment, word_pieces) def prepare_data(self) -> None: """Prepare IAM lines to be used to generate paragraphs.""" @@ -95,6 +97,7 @@ class IAMSyntheticParagraphs(IAMParagraphs): transform=get_transform( image_shape=self.dims[1:], augment=self.augment ), + target_transform=get_target_transform(self.word_pieces) ) def __repr__(self) -> str: diff --git 
a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py index cfa0ec7..f4016ba 100644 --- a/text_recognizer/data/mappings.py +++ b/text_recognizer/data/mappings.py @@ -8,7 +8,7 @@ import torch from torch import Tensor from text_recognizer.data.emnist import emnist_mapping -from text_recognizer.datasets.iam_preprocessor import Preprocessor +from text_recognizer.data.iam_preprocessor import Preprocessor class AbstractMapping(ABC): @@ -57,14 +57,14 @@ class EmnistMapping(AbstractMapping): class WordPieceMapping(EmnistMapping): def __init__( self, - num_features: int, - tokens: str, - lexicon: str, + num_features: int = 1000, + tokens: str = "iamdb_1kwp_tokens_1000.txt" , + lexicon: str = "iamdb_1kwp_lex_1000.txt", data_dir: Optional[Union[str, Path]] = None, use_words: bool = False, prepend_wordsep: bool = False, special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
- extra_symbols: Optional[Sequence[str]] = None, + extra_symbols: Optional[Sequence[str]] = ("\n", ), ) -> None: super().__init__(extra_symbols) self.wordpiece_processor = self._configure_wordpiece_processor( @@ -78,8 +78,8 @@ class WordPieceMapping(EmnistMapping): extra_symbols, ) + @staticmethod def _configure_wordpiece_processor( - self, num_features: int, tokens: str, lexicon: str, @@ -90,7 +90,7 @@ class WordPieceMapping(EmnistMapping): extra_symbols: Optional[Sequence[str]], ) -> Preprocessor: data_dir = ( - (Path(__file__).resolve().parents[2] / "data" / "raw" / "iam" / "iamdb") + (Path(__file__).resolve().parents[2] / "data" / "downloaded" / "iam" / "iamdb") if data_dir is None else Path(data_dir) ) @@ -138,6 +138,6 @@ class WordPieceMapping(EmnistMapping): return self.wordpiece_processor.to_index(text) def emnist_to_wordpiece_indices(self, x: Tensor) -> Tensor: - text = self.mapping.get_text(x) + text = "".join([self.mapping[i] for i in x]) text = text.lower().replace(" ", "▁") return torch.LongTensor(self.wordpiece_processor.to_index(text)) diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py index f53df64..8d1bedd 100644 --- a/text_recognizer/data/transforms.py +++ b/text_recognizer/data/transforms.py @@ -4,7 +4,7 @@ from typing import Optional, Union, Sequence from torch import Tensor -from text_recognizer.datasets.mappings import WordPieceMapping +from text_recognizer.data.mappings import WordPieceMapping class WordPiece: @@ -12,14 +12,15 @@ class WordPiece: def __init__( self, - num_features: int, - tokens: str, - lexicon: str, + num_features: int = 1000, + tokens: str = "iamdb_1kwp_tokens_1000.txt" , + lexicon: str = "iamdb_1kwp_lex_1000.txt", data_dir: Optional[Union[str, Path]] = None, use_words: bool = False, prepend_wordsep: bool = False, special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
- extra_symbols: Optional[Sequence[str]] = None, + extra_symbols: Optional[Sequence[str]] = ("\n",), + max_len: int = 192, ) -> None: self.mapping = WordPieceMapping( num_features, @@ -31,6 +32,7 @@ class WordPiece: special_tokens, extra_symbols, ) + self.max_len = max_len def __call__(self, x: Tensor) -> Tensor: - return self.mapping.emnist_to_wordpiece_indices(x) + return self.mapping.emnist_to_wordpiece_indices(x)[:self.max_len] diff --git a/text_recognizer/models/__init__.py b/text_recognizer/models/__init__.py index e69de29..5ac2510 100644 --- a/text_recognizer/models/__init__.py +++ b/text_recognizer/models/__init__.py @@ -0,0 +1,3 @@ +"""PyTorch Lightning models modules.""" +from .transformer import LitTransformerModel +from .vqvae import LitVQVAEModel diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index aeda039..88ffde6 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -40,6 +40,15 @@ class LitBaseModel(pl.LightningModule): args = {} or criterion.args return getattr(nn, criterion.type)(**args) + def optimizer_zero_grad( + self, + epoch: int, + batch_idx: int, + optimizer: Type[torch.optim.Optimizer], + optimizer_idx: int, + ) -> None: + optimizer.zero_grad(set_to_none=True) + def _configure_optimizer(self) -> Type[torch.optim.Optimizer]: """Configures the optimizer.""" args = {} or self._optimizer.args diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py new file mode 100644 index 0000000..ef2213c --- /dev/null +++ b/text_recognizer/models/vqvae.py @@ -0,0 +1,70 @@ +"""PyTorch Lightning model for base Transformers.""" +from typing import Any, Dict, Union, Tuple, Type + +from omegaconf import DictConfig, OmegaConf +from torch import nn +from torch import Tensor +import torch.nn.functional as F +import wandb + +from text_recognizer.models.base import LitBaseModel + + +class LitVQVAEModel(LitBaseModel): + """A PyTorch Lightning model for transformer networks.""" +
def __init__( + self, + network: Type[nn.Module], + optimizer: Union[DictConfig, Dict], + lr_scheduler: Union[DictConfig, Dict], + criterion: Union[DictConfig, Dict], + monitor: str = "val_loss", + *args: Any, + **kwargs: Dict, + ) -> None: + super().__init__(network, optimizer, lr_scheduler, criterion, monitor) + + def forward(self, data: Tensor) -> Tensor: + """Forward pass with the transformer network.""" + return self.network.predict(data) + + def _log_prediction(self, data: Tensor, reconstructions: Tensor) -> None: + """Logs prediction on image with wandb.""" + try: + self.logger.experiment.log( + { + "val_pred_examples": [ + wandb.Image(data[0]), + wandb.Image(reconstructions[0]), + ] + } + ) + except AttributeError: + pass + + def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: + """Training step.""" + data, _ = batch + reconstructions, vq_loss = self.network(data) + loss = self.loss_fn(reconstructions, data) + loss += vq_loss + self.log("train_loss", loss) + return loss + + def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: + """Validation step.""" + data, _ = batch + reconstructions, vq_loss = self.network(data) + loss = self.loss_fn(reconstructions, data) + loss += vq_loss + self.log("val_loss", loss, prog_bar=True) + self._log_prediction(data, reconstructions) + + def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: + """Test step.""" + data, _ = batch + reconstructions, vq_loss = self.network(data) + loss = self.loss_fn(reconstructions, data) + loss += vq_loss + self._log_prediction(data, reconstructions) diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py index 979149f..41fd43f 100644 --- a/text_recognizer/networks/__init__.py +++ b/text_recognizer/networks/__init__.py @@ -1,2 +1,2 @@ """Network modules""" -from .image_transformer import ImageTransformer +from .vqvae import VQVAE diff --git a/text_recognizer/networks/cnn_transformer.py 
b/text_recognizer/networks/cnn_transformer.py index 9150b55..e23a15d 100644 --- a/text_recognizer/networks/cnn_transformer.py +++ b/text_recognizer/networks/cnn_transformer.py @@ -1,158 +1,165 @@ -"""A CNN-Transformer for image to text recognition.""" -from typing import Dict, Optional, Tuple +"""A Transformer with a cnn backbone. + +The network encodes a image with a convolutional backbone to a latent representation, +i.e. feature maps. A 2d positional encoding is applied to the feature maps for +spatial information. The resulting feature are then set to a transformer decoder +together with the target tokens. + +TODO: Local attention for lower layer in attention. + +""" +import importlib +import math +from typing import Dict, Optional, Union, Sequence, Type from einops import rearrange +from omegaconf import DictConfig, OmegaConf import torch from torch import nn from torch import Tensor -from text_recognizer.networks.transformer import PositionalEncoding, Transformer -from text_recognizer.networks.util import activation_function -from text_recognizer.networks.util import configure_backbone +from text_recognizer.data.emnist import NUM_SPECIAL_TOKENS +from text_recognizer.networks.transformer import ( + Decoder, + DecoderLayer, + PositionalEncoding, + PositionalEncoding2D, + target_padding_mask, +) +NUM_WORD_PIECES = 1000 -class CNNTransformer(nn.Module): - """CNN+Transfomer for image to sequence prediction.""" +class CNNTransformer(nn.Module): def __init__( self, - num_encoder_layers: int, - num_decoder_layers: int, - hidden_dim: int, - vocab_size: int, - num_heads: int, - adaptive_pool_dim: Tuple, - expansion_dim: int, - dropout_rate: float, - trg_pad_index: int, - max_len: int, - backbone: str, - backbone_args: Optional[Dict] = None, - activation: str = "gelu", - pool_kernel: Optional[Tuple[int, int]] = None, + input_shape: Sequence[int], + output_shape: Sequence[int], + encoder: Union[DictConfig, Dict], + vocab_size: Optional[int] = None, + num_decoder_layers: 
int = 4, + hidden_dim: int = 256, + num_heads: int = 4, + expansion_dim: int = 1024, + dropout_rate: float = 0.1, + transformer_activation: str = "glu", ) -> None: - super().__init__() - self.trg_pad_index = trg_pad_index - self.vocab_size = vocab_size - self.backbone = configure_backbone(backbone, backbone_args) - - if pool_kernel is not None: - self.max_pool = nn.MaxPool2d(pool_kernel, stride=2) - else: - self.max_pool = None - - self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim) - - self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim)) - self.pos_dropout = nn.Dropout(p=dropout_rate) - self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate) - - nn.init.normal_(self.character_embedding.weight, std=0.02) - - self.adaptive_pool = ( - nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None + self.vocab_size = ( + NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size ) + self.hidden_dim = hidden_dim + self.max_output_length = output_shape[0] - self.transformer = Transformer( - num_encoder_layers, - num_decoder_layers, - hidden_dim, - num_heads, - expansion_dim, - dropout_rate, - activation, + # Image backbone + self.encoder = self._configure_encoder(encoder) + self.feature_map_encoding = PositionalEncoding2D( + hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2] ) - self.head = nn.Sequential( - # nn.Linear(hidden_dim, hidden_dim * 2), - # activation_function(activation), - nn.Linear(hidden_dim, vocab_size), - ) + # Target token embedding + self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim) + self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate) - def _create_trg_mask(self, trg: Tensor) -> Tensor: - # Move this outside the transformer. 
- trg_pad_mask = (trg != self.trg_pad_index)[:, None, None] - trg_len = trg.shape[1] - trg_sub_mask = torch.tril( - torch.ones((trg_len, trg_len), device=trg.device) - ).bool() - trg_mask = trg_pad_mask & trg_sub_mask - return trg_mask - - def encoder(self, src: Tensor) -> Tensor: - """Forward pass with the encoder of the transformer.""" - return self.transformer.encoder(src) - - def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor: - """Forward pass with the decoder of the transformer + classification head.""" - return self.head( - self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask) + # Transformer decoder + self.decoder = Decoder( + decoder_layer=DecoderLayer( + hidden_dim=hidden_dim, + num_heads=num_heads, + expansion_dim=expansion_dim, + dropout_rate=dropout_rate, + activation=transformer_activation, + ), + num_layers=num_decoder_layers, + norm=nn.LayerNorm(hidden_dim), ) - def extract_image_features(self, src: Tensor) -> Tensor: - """Extracts image features with a backbone neural network. - - It seem like the winning idea was to swap channels and width dimension and collapse - the height dimension. The transformer is learning like a baby with this implementation!!! :D - Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D + # Classification head + self.head = nn.Linear(hidden_dim, self.vocab_size) - Args: - src (Tensor): Input tensor. + # Initialize weights + self._init_weights() - Returns: - Tensor: A input src to the transformer. + def _init_weights(self) -> None: + """Initialize network weights.""" + self.trg_embedding.weight.data.uniform_(-0.1, 0.1) + self.head.bias.data.zero_() + self.head.weight.data.uniform_(-0.1, 0.1) - """ - # If batch dimension is missing, it needs to be added. 
- if len(src.shape) < 4: - src = src[(None,) * (4 - len(src.shape))] - - src = self.backbone(src) - - if self.max_pool is not None: - src = self.max_pool(src) - - if self.adaptive_pool is not None and len(src.shape) == 4: - src = rearrange(src, "b c h w -> b w c h") - src = self.adaptive_pool(src) - src = src.squeeze(3) - elif len(src.shape) == 4: - src = rearrange(src, "b c h w -> b (h w) c") + nn.init.kaiming_normal_( + self.feature_map_encoding.weight.data, + a=0, + mode="fan_out", + nonlinearity="relu", + ) + if self.feature_map_encoding.bias is not None: + _, fan_out = nn.init._calculate_fan_in_and_fan_out( + self.feature_map_encoding.weight.data + ) + bound = 1 / math.sqrt(fan_out) + nn.init.normal_(self.feature_map_encoding.bias, -bound, bound) + + @staticmethod + def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]: + encoder = OmegaConf.create(encoder) + network_module = importlib.import_module("text_recognizer.networks") + encoder_class = getattr(network_module, encoder.type) + return encoder_class(**encoder.args) + + def encode(self, image: Tensor) -> Tensor: + """Extracts image features with backbone. - b, t, _ = src.shape + Args: + image (Tensor): Image(s) of handwritten text. - src += self.src_position_embedding[:, :t] - src = self.pos_dropout(src) + Retuns: + Tensor: Image features. - return src + Shapes: + - image: :math: `(B, C, H, W)` + - latent: :math: `(B, T, C)` - def target_embedding(self, trg: Tensor) -> Tuple[Tensor, Tensor]: - """Encodes target tensor with embedding and postion. + """ + # Extract image features. + image_features = self.encoder(image) - Args: - trg (Tensor): Target tensor. + # Add 2d encoding to the feature maps. + image_features = self.feature_map_encoding(image_features) - Returns: - Tuple[Tensor, Tensor]: Encoded target tensor and target mask. + # Collapse features maps height and width. 
+ image_features = rearrange(image_features, "b c h w -> b (h w) c") + return image_features - """ - trg = self.character_embedding(trg.long()) + def decode(self, memory: Tensor, trg: Tensor) -> Tensor: + """Decodes image features with transformer decoder.""" + trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index) + trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim) trg = self.trg_position_encoding(trg) - return trg - - def decode_image_features( - self, image_features: Tensor, trg: Optional[Tensor] = None - ) -> Tensor: - """Takes images features from the backbone and decodes them with the transformer.""" - trg_mask = self._create_trg_mask(trg) - trg = self.target_embedding(trg) - out = self.transformer(image_features, trg, trg_mask=trg_mask) - + out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None) logits = self.head(out) return logits - def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor: - """Forward pass with CNN transfomer.""" - image_features = self.extract_image_features(x) - logits = self.decode_image_features(image_features, trg) - return logits + def predict(self, image: Tensor) -> Tensor: + """Transcribes text in image(s).""" + bsz = image.shape[0] + image_features = self.encode(image) + + output_tokens = ( + (torch.ones((bsz, self.max_output_length)) * self.pad_index) + .type_as(image) + .long() + ) + output_tokens[:, 0] = self.start_index + for i in range(1, self.max_output_length): + trg = output_tokens[:, :i] + output = self.decode(image_features, trg) + output = torch.argmax(output, dim=-1) + output_tokens[:, i] = output[-1:] + + # Set all tokens after end token to be padding. 
+ for i in range(1, self.max_output_length): + indices = output_tokens[:, i - 1] == self.end_index | ( + output_tokens[:, i - 1] == self.pad_index + ) + output_tokens[indices, i] = self.pad_index + + return output_tokens diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py deleted file mode 100644 index a6aaca4..0000000 --- a/text_recognizer/networks/image_transformer.py +++ /dev/null @@ -1,165 +0,0 @@ -"""A Transformer with a cnn backbone. - -The network encodes a image with a convolutional backbone to a latent representation, -i.e. feature maps. A 2d positional encoding is applied to the feature maps for -spatial information. The resulting feature are then set to a transformer decoder -together with the target tokens. - -TODO: Local attention for lower layer in attention. - -""" -import importlib -import math -from typing import Dict, Optional, Union, Sequence, Type - -from einops import rearrange -from omegaconf import DictConfig, OmegaConf -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.data.emnist import NUM_SPECIAL_TOKENS -from text_recognizer.networks.transformer import ( - Decoder, - DecoderLayer, - PositionalEncoding, - PositionalEncoding2D, - target_padding_mask, -) - -NUM_WORD_PIECES = 1000 - - -class ImageTransformer(nn.Module): - def __init__( - self, - input_shape: Sequence[int], - output_shape: Sequence[int], - encoder: Union[DictConfig, Dict], - vocab_size: Optional[int] = None, - num_decoder_layers: int = 4, - hidden_dim: int = 256, - num_heads: int = 4, - expansion_dim: int = 1024, - dropout_rate: float = 0.1, - transformer_activation: str = "glu", - ) -> None: - self.vocab_size = ( - NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size - ) - self.hidden_dim = hidden_dim - self.max_output_length = output_shape[0] - - # Image backbone - self.encoder = self._configure_encoder(encoder) - self.feature_map_encoding = PositionalEncoding2D( - 
hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2] - ) - - # Target token embedding - self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim) - self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate) - - # Transformer decoder - self.decoder = Decoder( - decoder_layer=DecoderLayer( - hidden_dim=hidden_dim, - num_heads=num_heads, - expansion_dim=expansion_dim, - dropout_rate=dropout_rate, - activation=transformer_activation, - ), - num_layers=num_decoder_layers, - norm=nn.LayerNorm(hidden_dim), - ) - - # Classification head - self.head = nn.Linear(hidden_dim, self.vocab_size) - - # Initialize weights - self._init_weights() - - def _init_weights(self) -> None: - """Initialize network weights.""" - self.trg_embedding.weight.data.uniform_(-0.1, 0.1) - self.head.bias.data.zero_() - self.head.weight.data.uniform_(-0.1, 0.1) - - nn.init.kaiming_normal_( - self.feature_map_encoding.weight.data, - a=0, - mode="fan_out", - nonlinearity="relu", - ) - if self.feature_map_encoding.bias is not None: - _, fan_out = nn.init._calculate_fan_in_and_fan_out( - self.feature_map_encoding.weight.data - ) - bound = 1 / math.sqrt(fan_out) - nn.init.normal_(self.feature_map_encoding.bias, -bound, bound) - - @staticmethod - def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]: - encoder = OmegaConf.create(encoder) - network_module = importlib.import_module("text_recognizer.networks") - encoder_class = getattr(network_module, encoder.type) - return encoder_class(**encoder.args) - - def encode(self, image: Tensor) -> Tensor: - """Extracts image features with backbone. - - Args: - image (Tensor): Image(s) of handwritten text. - - Retuns: - Tensor: Image features. - - Shapes: - - image: :math: `(B, C, H, W)` - - latent: :math: `(B, T, C)` - - """ - # Extract image features. - image_features = self.encoder(image) - - # Add 2d encoding to the feature maps. 
- image_features = self.feature_map_encoding(image_features) - - # Collapse features maps height and width. - image_features = rearrange(image_features, "b c h w -> b (h w) c") - return image_features - - def decode(self, memory: Tensor, trg: Tensor) -> Tensor: - """Decodes image features with transformer decoder.""" - trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index) - trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim) - trg = self.trg_position_encoding(trg) - out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None) - logits = self.head(out) - return logits - - def predict(self, image: Tensor) -> Tensor: - """Transcribes text in image(s).""" - bsz = image.shape[0] - image_features = self.encode(image) - - output_tokens = ( - (torch.ones((bsz, self.max_output_length)) * self.pad_index) - .type_as(image) - .long() - ) - output_tokens[:, 0] = self.start_index - for i in range(1, self.max_output_length): - trg = output_tokens[:, :i] - output = self.decode(image_features, trg) - output = torch.argmax(output, dim=-1) - output_tokens[:, i] = output[-1:] - - # Set all tokens after end token to be padding. 
- for i in range(1, self.max_output_length): - indices = output_tokens[:, i - 1] == self.end_index | ( - output_tokens[:, i - 1] == self.pad_index - ) - output_tokens[indices, i] = self.pad_index - - return output_tokens diff --git a/text_recognizer/networks/residual_network.py b/text_recognizer/networks/residual_network.py index c33f419..da7553d 100644 --- a/text_recognizer/networks/residual_network.py +++ b/text_recognizer/networks/residual_network.py @@ -20,7 +20,11 @@ class Conv2dAuto(nn.Conv2d): def conv_bn(in_channels: int, out_channels: int, *args, **kwargs) -> nn.Sequential: """3x3 convolution with batch norm.""" - conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,) + conv3x3 = partial( + Conv2dAuto, + kernel_size=3, + bias=False, + ) return nn.Sequential( conv3x3(in_channels, out_channels, *args, **kwargs), nn.BatchNorm2d(out_channels), diff --git a/text_recognizer/networks/transducer/transducer.py b/text_recognizer/networks/transducer/transducer.py index d7e3d08..b10f93a 100644 --- a/text_recognizer/networks/transducer/transducer.py +++ b/text_recognizer/networks/transducer/transducer.py @@ -392,7 +392,12 @@ def load_transducer_loss( transitions = gtn.load(str(processed_path / transitions)) preprocessor = Preprocessor( - data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep, + data_dir, + num_features, + tokens_path, + lexicon_path, + use_words, + prepend_wordsep, ) num_tokens = preprocessor.num_tokens diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py index 8847aba..93a1e43 100644 --- a/text_recognizer/networks/vqvae/decoder.py +++ b/text_recognizer/networks/vqvae/decoder.py @@ -44,7 +44,12 @@ class Decoder(nn.Module): # Configure encoder. 
self.decoder = self._build_decoder( - channels, kernel_sizes, strides, num_residual_layers, activation, dropout, + channels, + kernel_sizes, + strides, + num_residual_layers, + activation, + dropout, ) def _build_decompression_block( @@ -72,8 +77,10 @@ class Decoder(nn.Module): ) ) - if i < len(self.upsampling): - modules.append(nn.Upsample(size=self.upsampling[i]),) + if self.upsampling and i < len(self.upsampling): + modules.append( + nn.Upsample(size=self.upsampling[i]), + ) if dropout is not None: modules.append(dropout) @@ -102,7 +109,12 @@ class Decoder(nn.Module): ) -> nn.Sequential: self.res_block.append( - nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,) + nn.Conv2d( + self.embedding_dim, + channels[0], + kernel_size=1, + stride=1, + ) ) # Bottleneck module. diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py index d3adac5..b0cceed 100644 --- a/text_recognizer/networks/vqvae/encoder.py +++ b/text_recognizer/networks/vqvae/encoder.py @@ -1,5 +1,5 @@ """CNN encoder for the VQ-VAE.""" -from typing import List, Optional, Tuple, Type +from typing import Sequence, Optional, Tuple, Type import torch from torch import nn @@ -11,7 +11,10 @@ from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer class _ResidualBlock(nn.Module): def __init__( - self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]], + self, + in_channels: int, + out_channels: int, + dropout: Optional[Type[nn.Module]], ) -> None: super().__init__() self.block = [ @@ -36,9 +39,9 @@ class Encoder(nn.Module): def __init__( self, in_channels: int, - channels: List[int], - kernel_sizes: List[int], - strides: List[int], + channels: Sequence[int], + kernel_sizes: Sequence[int], + strides: Sequence[int], num_residual_layers: int, embedding_dim: int, num_embeddings: int, @@ -77,12 +80,12 @@ class Encoder(nn.Module): self.num_embeddings, self.embedding_dim, self.beta ) + @staticmethod def 
_build_compression_block( - self, in_channels: int, channels: int, - kernel_sizes: List[int], - strides: List[int], + kernel_sizes: Sequence[int], + strides: Sequence[int], activation: Type[nn.Module], dropout: Optional[Type[nn.Module]], ) -> nn.ModuleList: @@ -109,8 +112,8 @@ class Encoder(nn.Module): self, in_channels: int, channels: int, - kernel_sizes: List[int], - strides: List[int], + kernel_sizes: Sequence[int], + strides: Sequence[int], num_residual_layers: int, activation: Type[nn.Module], dropout: Optional[Type[nn.Module]], @@ -135,7 +138,12 @@ class Encoder(nn.Module): ) encoder.append( - nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,) + nn.Conv2d( + channels[-1], + self.embedding_dim, + kernel_size=1, + stride=1, + ) ) return nn.Sequential(*encoder) diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py index 50448b4..1f08e5e 100644 --- a/text_recognizer/networks/vqvae/vqvae.py +++ b/text_recognizer/networks/vqvae/vqvae.py @@ -1,8 +1,7 @@ """The VQ-VAE.""" -from typing import List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple -import torch from torch import nn from torch import Tensor @@ -25,6 +24,8 @@ class VQVAE(nn.Module): beta: float = 0.25, activation: str = "leaky_relu", dropout_rate: float = 0.0, + *args: Any, + **kwargs: Dict, ) -> None: super().__init__() -- cgit v1.2.3-70-g09d2