Diffstat (limited to 'text_recognizer')
-rw-r--r--  text_recognizer/data/__init__.py                   |   3
-rw-r--r--  text_recognizer/data/emnist_lines.py               |   2
-rw-r--r--  text_recognizer/data/iam_extended_paragraphs.py    |  15
-rw-r--r--  text_recognizer/data/iam_paragraphs.py             |  23
-rw-r--r--  text_recognizer/data/iam_preprocessor.py           |   1
-rw-r--r--  text_recognizer/data/iam_synthetic_paragraphs.py   |   7
-rw-r--r--  text_recognizer/data/mappings.py                   |  16
-rw-r--r--  text_recognizer/data/transforms.py                 |  14
-rw-r--r--  text_recognizer/models/__init__.py                 |   3
-rw-r--r--  text_recognizer/models/base.py                     |   9
-rw-r--r--  text_recognizer/models/vqvae.py                    |  70
-rw-r--r--  text_recognizer/networks/__init__.py               |   2
-rw-r--r--  text_recognizer/networks/cnn_transformer.py        | 257
-rw-r--r--  text_recognizer/networks/image_transformer.py      | 165
-rw-r--r--  text_recognizer/networks/residual_network.py       |   6
-rw-r--r--  text_recognizer/networks/transducer/transducer.py  |   7
-rw-r--r--  text_recognizer/networks/vqvae/decoder.py          |  20
-rw-r--r--  text_recognizer/networks/vqvae/encoder.py          |  30
-rw-r--r--  text_recognizer/networks/vqvae/vqvae.py            |   5
19 files changed, 318 insertions, 337 deletions
diff --git a/text_recognizer/data/__init__.py b/text_recognizer/data/__init__.py
index 9a42fa9..3599a8b 100644
--- a/text_recognizer/data/__init__.py
+++ b/text_recognizer/data/__init__.py
@@ -2,3 +2,6 @@
 from .base_dataset import BaseDataset, convert_strings_to_labels, split_dataset
 from .base_data_module import BaseDataModule, load_and_print_info
 from .download_utils import download_dataset
+from .iam_paragraphs import IAMParagraphs
+from .iam_synthetic_paragraphs import IAMSyntheticParagraphs
+from .iam_extended_paragraphs import IAMExtendedParagraphs
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 72665d0..9650198 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -57,8 +57,8 @@ class EMNISTLines(BaseDataModule):
         self.num_test = num_test
         self.emnist = EMNIST()
-        # TODO: fix mapping
         self.mapping = self.emnist.mapping
+
         max_width = (
             int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap))
             + IMAGE_X_PADDING
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index d2529b4..2380660 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -10,18 +10,27 @@ from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs
 class IAMExtendedParagraphs(BaseDataModule):
     def __init__(
         self,
-        batch_size: int = 128,
+        batch_size: int = 16,
         num_workers: int = 0,
         train_fraction: float = 0.8,
         augment: bool = True,
+        word_pieces: bool = False,
     ) -> None:
         super().__init__(batch_size, num_workers)
         self.iam_paragraphs = IAMParagraphs(
-            batch_size, num_workers, train_fraction, augment,
+            batch_size,
+            num_workers,
+            train_fraction,
+            augment,
+            word_pieces,
         )
         self.iam_synthetic_paragraphs = IAMSyntheticParagraphs(
-            batch_size, num_workers, train_fraction, augment,
+            batch_size,
+            num_workers,
+            train_fraction,
+            augment,
+            word_pieces,
         )
         self.dims = self.iam_paragraphs.dims
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index f588587..62c44f9 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -5,8 +5,7 @@ from typing import Dict, List, Optional, Sequence, Tuple
 from loguru import logger
 import numpy as np
-from PIL import Image, ImageFile, ImageOps
-import torch
+from PIL import Image, ImageOps
 import torchvision.transforms as transforms
 from torchvision.transforms.functional import InterpolationMode
 from tqdm import tqdm
@@ -19,6 +18,7 @@ from text_recognizer.data.base_dataset import (
 from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
 from text_recognizer.data.emnist import emnist_mapping
 from text_recognizer.data.iam import IAM
+from text_recognizer.data.transforms import WordPiece
 
 PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_paragraphs"
@@ -37,15 +37,15 @@ class IAMParagraphs(BaseDataModule):
     def __init__(
         self,
-        batch_size: int = 128,
+        batch_size: int = 16,
         num_workers: int = 0,
         train_fraction: float = 0.8,
         augment: bool = True,
+        word_pieces: bool = False,
     ) -> None:
         super().__init__(batch_size, num_workers)
-        # TODO: pass in transform and target transform
-        # TODO: pass in mapping
         self.augment = augment
+        self.word_pieces = word_pieces
         self.mapping, self.inverse_mapping, _ = emnist_mapping(
             extra_symbols=[NEW_LINE_TOKEN]
         )
@@ -101,6 +101,7 @@ class IAMParagraphs(BaseDataModule):
                 data,
                 targets,
                 transform=get_transform(image_shape=self.dims[1:], augment=augment),
+                target_transform=get_target_transform(self.word_pieces)
             )
 
         logger.info(f"Loading IAM paragraph regions and lines for {stage}...")
@@ -161,7 +162,10 @@ def get_dataset_properties() -> Dict:
             "min": min(_get_property_values("num_lines")),
             "max": max(_get_property_values("num_lines")),
         },
-        "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0),},
+        "crop_shape": {
+            "min": crop_shapes.min(axis=0),
+            "max": crop_shapes.max(axis=0),
+        },
         "aspect_ratio": {
             "min": aspect_ratio.min(axis=0),
             "max": aspect_ratio.max(axis=0),
@@ -282,7 +286,9 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Com
             ),
             transforms.ColorJitter(brightness=(0.8, 1.6)),
             transforms.RandomAffine(
-                degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR,
+                degrees=1,
+                shear=(-10, 10),
+                interpolation=InterpolationMode.BILINEAR,
             ),
         ]
     else:
@@ -290,6 +296,9 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Com
     transforms_list.append(transforms.ToTensor())
     return transforms.Compose(transforms_list)
 
+def get_target_transform(word_pieces: bool) -> Optional[transforms.Compose]:
+    """Transform emnist characters to word pieces."""
+    return transforms.Compose([WordPiece()]) if word_pieces else None
 
 def _labels_filename(split: str) -> Path:
     """Return filename of processed labels."""
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py
index 60f8a9f..b5f72da 100644
--- a/text_recognizer/data/iam_preprocessor.py
+++ b/text_recognizer/data/iam_preprocessor.py
@@ -89,6 +89,7 @@ class Preprocessor:
             self.lexicon = None
 
         if self.special_tokens is not None:
+            self.special_tokens += ("#", "*")
             self.tokens += self.special_tokens
             self.graphemes += self.special_tokens
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index 9f1bd12..4ccc5c2 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -18,6 +18,7 @@ from text_recognizer.data.base_data_module import BaseDataModule, load_and_print
 from text_recognizer.data.iam_paragraphs import (
     get_dataset_properties,
     get_transform,
+    get_target_transform,
     NEW_LINE_TOKEN,
     IAMParagraphs,
     IMAGE_SCALE_FACTOR,
@@ -41,12 +42,13 @@ class IAMSyntheticParagraphs(IAMParagraphs):
     def __init__(
         self,
-        batch_size: int = 128,
+        batch_size: int = 16,
         num_workers: int = 0,
         train_fraction: float = 0.8,
         augment: bool = True,
+        word_pieces: bool = False,
     ) -> None:
-        super().__init__(batch_size, num_workers, train_fraction, augment)
+        super().__init__(batch_size, num_workers, train_fraction, augment, word_pieces)
 
     def prepare_data(self) -> None:
         """Prepare IAM lines to be used to generate paragraphs."""
@@ -95,6 +97,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
                 transform=get_transform(
                     image_shape=self.dims[1:], augment=self.augment
                 ),
+                target_transform=get_target_transform(self.word_pieces)
             )
 
     def __repr__(self) -> str:
diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py
index cfa0ec7..f4016ba 100644
--- a/text_recognizer/data/mappings.py
+++ b/text_recognizer/data/mappings.py
@@ -8,7 +8,7 @@ import torch
 from torch import Tensor
 
 from text_recognizer.data.emnist import emnist_mapping
-from text_recognizer.datasets.iam_preprocessor import Preprocessor
+from text_recognizer.data.iam_preprocessor import Preprocessor
 
 
 class AbstractMapping(ABC):
@@ -57,14 +57,14 @@ class EmnistMapping(AbstractMapping):
 class WordPieceMapping(EmnistMapping):
     def __init__(
         self,
-        num_features: int,
-        tokens: str,
-        lexicon: str,
+        num_features: int = 1000,
+        tokens: str = "iamdb_1kwp_tokens_1000.txt" ,
+        lexicon: str = "iamdb_1kwp_lex_1000.txt",
         data_dir: Optional[Union[str, Path]] = None,
         use_words: bool = False,
         prepend_wordsep: bool = False,
         special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
-        extra_symbols: Optional[Sequence[str]] = None,
+        extra_symbols: Optional[Sequence[str]] = ("\n", ),
     ) -> None:
         super().__init__(extra_symbols)
         self.wordpiece_processor = self._configure_wordpiece_processor(
@@ -78,8 +78,8 @@ class WordPieceMapping(EmnistMapping):
             extra_symbols,
         )
 
+    @staticmethod
     def _configure_wordpiece_processor(
-        self,
         num_features: int,
         tokens: str,
         lexicon: str,
@@ -90,7 +90,7 @@ class WordPieceMapping(EmnistMapping):
         extra_symbols: Optional[Sequence[str]],
     ) -> Preprocessor:
         data_dir = (
-            (Path(__file__).resolve().parents[2] / "data" / "raw" / "iam" / "iamdb")
+            (Path(__file__).resolve().parents[2] / "data" / "downloaded" / "iam" / "iamdb")
             if data_dir is None
             else Path(data_dir)
         )
@@ -138,6 +138,6 @@ class WordPieceMapping(EmnistMapping):
         return self.wordpiece_processor.to_index(text)
 
     def emnist_to_wordpiece_indices(self, x: Tensor) -> Tensor:
-        text = self.mapping.get_text(x)
+        text = "".join([self.mapping[i] for i in x])
         text = text.lower().replace(" ", "▁")
         return torch.LongTensor(self.wordpiece_processor.to_index(text))
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
index f53df64..8d1bedd 100644
--- a/text_recognizer/data/transforms.py
+++ b/text_recognizer/data/transforms.py
@@ -4,7 +4,7 @@ from typing import Optional, Union, Sequence
 
 from torch import Tensor
 
-from text_recognizer.datasets.mappings import WordPieceMapping
+from text_recognizer.data.mappings import WordPieceMapping
 
 
 class WordPiece:
@@ -12,14 +12,15 @@ class WordPiece:
     def __init__(
         self,
-        num_features: int,
-        tokens: str,
-        lexicon: str,
+        num_features: int = 1000,
+        tokens: str = "iamdb_1kwp_tokens_1000.txt" ,
+        lexicon: str = "iamdb_1kwp_lex_1000.txt",
         data_dir: Optional[Union[str, Path]] = None,
         use_words: bool = False,
         prepend_wordsep: bool = False,
         special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
-        extra_symbols: Optional[Sequence[str]] = None,
+        extra_symbols: Optional[Sequence[str]] = ("\n",),
+        max_len: int = 192,
     ) -> None:
         self.mapping = WordPieceMapping(
             num_features,
@@ -31,6 +32,7 @@ class WordPiece:
             special_tokens,
             extra_symbols,
         )
+        self.max_len = max_len
 
     def __call__(self, x: Tensor) -> Tensor:
-        return self.mapping.emnist_to_wordpiece_indices(x)
+        return self.mapping.emnist_to_wordpiece_indices(x)[:self.max_len]
diff --git a/text_recognizer/models/__init__.py b/text_recognizer/models/__init__.py
index e69de29..5ac2510 100644
--- a/text_recognizer/models/__init__.py
+++ b/text_recognizer/models/__init__.py
@@ -0,0 +1,3 @@
+"""PyTorch Lightning models modules."""
+from .transformer import LitTransformerModel
+from .vqvae import LitVQVAEModel
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index aeda039..88ffde6 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -40,6 +40,15 @@ class LitBaseModel(pl.LightningModule):
         args = {} or criterion.args
         return getattr(nn, criterion.type)(**args)
 
+    def optimizer_zero_grad(
+        self,
+        epoch: int,
+        batch_idx: int,
+        optimizer: Type[torch.optim.Optimizer],
+        optimizer_idx: int,
+    ) -> None:
+        optimizer.zero_grad(set_to_none=True)
+
     def _configure_optimizer(self) -> Type[torch.optim.Optimizer]:
         """Configures the optimizer."""
         args = {} or self._optimizer.args
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
new file mode 100644
index 0000000..ef2213c
--- /dev/null
+++ b/text_recognizer/models/vqvae.py
@@ -0,0 +1,70 @@
+"""PyTorch Lightning model for base Transformers."""
+from typing import Any, Dict, Union, Tuple, Type
+
+from omegaconf import DictConfig, OmegaConf
+from torch import nn
+from torch import Tensor
+import torch.nn.functional as F
+import wandb
+
+from text_recognizer.models.base import LitBaseModel
+
+
+class LitVQVAEModel(LitBaseModel):
+    """A PyTorch Lightning model for transformer networks."""
+
+    def __init__(
+        self,
+        network: Type[nn.Module],
+        optimizer: Union[DictConfig, Dict],
+        lr_scheduler: Union[DictConfig, Dict],
+        criterion: Union[DictConfig, Dict],
+        monitor: str = "val_loss",
+        *args: Any,
+        **kwargs: Dict,
+    ) -> None:
+        super().__init__(network, optimizer, lr_scheduler, criterion, monitor)
+
+    def forward(self, data: Tensor) -> Tensor:
+        """Forward pass with the transformer network."""
+        return self.network.predict(data)
+
+    def _log_prediction(self, data: Tensor, reconstructions: Tensor) -> None:
+        """Logs prediction on image with wandb."""
+        try:
+            self.logger.experiment.log(
+                {
+                    "val_pred_examples": [
+                        wandb.Image(data[0]),
+                        wandb.Image(reconstructions[0]),
+                    ]
+                }
+            )
+        except AttributeError:
+            pass
+
+    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
+        """Training step."""
+        data, _ = batch
+        reconstructions, vq_loss = self.network(data)
+        loss = self.loss_fn(reconstructions, data)
+        loss += vq_loss
+        self.log("train_loss", loss)
+        return loss
+
+    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
+        """Validation step."""
+        data, _ = batch
+        reconstructions, vq_loss = self.network(data)
+        loss = self.loss_fn(reconstructions, data)
+        loss += vq_loss
+        self.log("val_loss", loss, prog_bar=True)
+        self._log_prediction(data, reconstructions)
+
+    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
+        """Test step."""
+        data, _ = batch
+        reconstructions, vq_loss = self.network(data)
+        loss = self.loss_fn(reconstructions, data)
+        loss += vq_loss
+        self._log_prediction(data, reconstructions)
diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py
index 979149f..41fd43f 100644
--- a/text_recognizer/networks/__init__.py
+++ b/text_recognizer/networks/__init__.py
@@ -1,2 +1,2 @@
 """Network modules"""
-from .image_transformer import ImageTransformer
+from .vqvae import VQVAE
diff --git a/text_recognizer/networks/cnn_transformer.py b/text_recognizer/networks/cnn_transformer.py
index 9150b55..e23a15d 100644
--- a/text_recognizer/networks/cnn_transformer.py
+++ b/text_recognizer/networks/cnn_transformer.py
@@ -1,158 +1,165 @@
-"""A CNN-Transformer for image to text recognition."""
-from typing import Dict, Optional, Tuple
+"""A Transformer with a cnn backbone.
+
+The network encodes a image with a convolutional backbone to a latent representation,
+i.e. feature maps. A 2d positional encoding is applied to the feature maps for
+spatial information. The resulting feature are then set to a transformer decoder
+together with the target tokens.
+
+TODO: Local attention for lower layer in attention.
+
+"""
+import importlib
+import math
+from typing import Dict, Optional, Union, Sequence, Type
 
 from einops import rearrange
+from omegaconf import DictConfig, OmegaConf
 import torch
 from torch import nn
 from torch import Tensor
 
-from text_recognizer.networks.transformer import PositionalEncoding, Transformer
-from text_recognizer.networks.util import activation_function
-from text_recognizer.networks.util import configure_backbone
+from text_recognizer.data.emnist import NUM_SPECIAL_TOKENS
+from text_recognizer.networks.transformer import (
+    Decoder,
+    DecoderLayer,
+    PositionalEncoding,
+    PositionalEncoding2D,
+    target_padding_mask,
+)
 
+NUM_WORD_PIECES = 1000
 
-class CNNTransformer(nn.Module):
-    """CNN+Transfomer for image to sequence prediction."""
 
+class CNNTransformer(nn.Module):
     def __init__(
         self,
-        num_encoder_layers: int,
-        num_decoder_layers: int,
-        hidden_dim: int,
-        vocab_size: int,
-        num_heads: int,
-        adaptive_pool_dim: Tuple,
-        expansion_dim: int,
-        dropout_rate: float,
-        trg_pad_index: int,
-        max_len: int,
-        backbone: str,
-        backbone_args: Optional[Dict] = None,
-        activation: str = "gelu",
-        pool_kernel: Optional[Tuple[int, int]] = None,
+        input_shape: Sequence[int],
+        output_shape: Sequence[int],
+        encoder: Union[DictConfig, Dict],
+        vocab_size: Optional[int] = None,
+        num_decoder_layers: int = 4,
+        hidden_dim: int = 256,
+        num_heads: int = 4,
+        expansion_dim: int = 1024,
+        dropout_rate: float = 0.1,
+        transformer_activation: str = "glu",
     ) -> None:
-        super().__init__()
-        self.trg_pad_index = trg_pad_index
-        self.vocab_size = vocab_size
-        self.backbone = configure_backbone(backbone, backbone_args)
-
-        if pool_kernel is not None:
-            self.max_pool = nn.MaxPool2d(pool_kernel, stride=2)
-        else:
-            self.max_pool = None
-
-        self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim)
-
-        self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
-        self.pos_dropout = nn.Dropout(p=dropout_rate)
-        self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
-
-        nn.init.normal_(self.character_embedding.weight, std=0.02)
-
-        self.adaptive_pool = (
-            nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None
+        self.vocab_size = (
+            NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size
         )
+        self.hidden_dim = hidden_dim
+        self.max_output_length = output_shape[0]
 
-        self.transformer = Transformer(
-            num_encoder_layers,
-            num_decoder_layers,
-            hidden_dim,
-            num_heads,
-            expansion_dim,
-            dropout_rate,
-            activation,
+        # Image backbone
+        self.encoder = self._configure_encoder(encoder)
+        self.feature_map_encoding = PositionalEncoding2D(
+            hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2]
         )
 
-        self.head = nn.Sequential(
-            # nn.Linear(hidden_dim, hidden_dim * 2),
-            # activation_function(activation),
-            nn.Linear(hidden_dim, vocab_size),
-        )
+        # Target token embedding
+        self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim)
+        self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
 
-    def _create_trg_mask(self, trg: Tensor) -> Tensor:
-        # Move this outside the transformer.
-        trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
-        trg_len = trg.shape[1]
-        trg_sub_mask = torch.tril(
-            torch.ones((trg_len, trg_len), device=trg.device)
-        ).bool()
-        trg_mask = trg_pad_mask & trg_sub_mask
-        return trg_mask
-
-    def encoder(self, src: Tensor) -> Tensor:
-        """Forward pass with the encoder of the transformer."""
-        return self.transformer.encoder(src)
-
-    def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
-        """Forward pass with the decoder of the transformer + classification head."""
-        return self.head(
-            self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+        # Transformer decoder
+        self.decoder = Decoder(
+            decoder_layer=DecoderLayer(
+                hidden_dim=hidden_dim,
+                num_heads=num_heads,
+                expansion_dim=expansion_dim,
+                dropout_rate=dropout_rate,
+                activation=transformer_activation,
+            ),
+            num_layers=num_decoder_layers,
+            norm=nn.LayerNorm(hidden_dim),
         )
 
-    def extract_image_features(self, src: Tensor) -> Tensor:
-        """Extracts image features with a backbone neural network.
-
-        It seem like the winning idea was to swap channels and width dimension and collapse
-        the height dimension. The transformer is learning like a baby with this implementation!!! :D
-        Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D
+        # Classification head
+        self.head = nn.Linear(hidden_dim, self.vocab_size)
 
-        Args:
-            src (Tensor): Input tensor.
+        # Initialize weights
+        self._init_weights()
 
-        Returns:
-            Tensor: A input src to the transformer.
+    def _init_weights(self) -> None:
+        """Initialize network weights."""
+        self.trg_embedding.weight.data.uniform_(-0.1, 0.1)
+        self.head.bias.data.zero_()
+        self.head.weight.data.uniform_(-0.1, 0.1)
 
-        """
-        # If batch dimension is missing, it needs to be added.
-        if len(src.shape) < 4:
-            src = src[(None,) * (4 - len(src.shape))]
-
-        src = self.backbone(src)
-
-        if self.max_pool is not None:
-            src = self.max_pool(src)
-
-        if self.adaptive_pool is not None and len(src.shape) == 4:
-            src = rearrange(src, "b c h w -> b w c h")
-            src = self.adaptive_pool(src)
-            src = src.squeeze(3)
-        elif len(src.shape) == 4:
-            src = rearrange(src, "b c h w -> b (h w) c")
+        nn.init.kaiming_normal_(
+            self.feature_map_encoding.weight.data,
+            a=0,
+            mode="fan_out",
+            nonlinearity="relu",
+        )
+        if self.feature_map_encoding.bias is not None:
+            _, fan_out = nn.init._calculate_fan_in_and_fan_out(
+                self.feature_map_encoding.weight.data
+            )
+            bound = 1 / math.sqrt(fan_out)
+            nn.init.normal_(self.feature_map_encoding.bias, -bound, bound)
+
+    @staticmethod
+    def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]:
+        encoder = OmegaConf.create(encoder)
+        network_module = importlib.import_module("text_recognizer.networks")
+        encoder_class = getattr(network_module, encoder.type)
+        return encoder_class(**encoder.args)
+
+    def encode(self, image: Tensor) -> Tensor:
+        """Extracts image features with backbone.
 
-        b, t, _ = src.shape
+        Args:
+            image (Tensor): Image(s) of handwritten text.
 
-        src += self.src_position_embedding[:, :t]
-        src = self.pos_dropout(src)
+        Retuns:
+            Tensor: Image features.
 
-        return src
+        Shapes:
+            - image: :math: `(B, C, H, W)`
+            - latent: :math: `(B, T, C)`
 
-    def target_embedding(self, trg: Tensor) -> Tuple[Tensor, Tensor]:
-        """Encodes target tensor with embedding and postion.
+        """
+        # Extract image features.
+        image_features = self.encoder(image)
 
-        Args:
-            trg (Tensor): Target tensor.
+        # Add 2d encoding to the feature maps.
+        image_features = self.feature_map_encoding(image_features)
 
-        Returns:
-            Tuple[Tensor, Tensor]: Encoded target tensor and target mask.
+        # Collapse features maps height and width.
+        image_features = rearrange(image_features, "b c h w -> b (h w) c")
+        return image_features
 
-        """
-        trg = self.character_embedding(trg.long())
+    def decode(self, memory: Tensor, trg: Tensor) -> Tensor:
+        """Decodes image features with transformer decoder."""
+        trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index)
+        trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim)
         trg = self.trg_position_encoding(trg)
-        return trg
-
-    def decode_image_features(
-        self, image_features: Tensor, trg: Optional[Tensor] = None
-    ) -> Tensor:
-        """Takes images features from the backbone and decodes them with the transformer."""
-        trg_mask = self._create_trg_mask(trg)
-        trg = self.target_embedding(trg)
-        out = self.transformer(image_features, trg, trg_mask=trg_mask)
-
+        out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None)
         logits = self.head(out)
         return logits
 
-    def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
-        """Forward pass with CNN transfomer."""
-        image_features = self.extract_image_features(x)
-        logits = self.decode_image_features(image_features, trg)
-        return logits
+    def predict(self, image: Tensor) -> Tensor:
+        """Transcribes text in image(s)."""
+        bsz = image.shape[0]
+        image_features = self.encode(image)
+
+        output_tokens = (
+            (torch.ones((bsz, self.max_output_length)) * self.pad_index)
+            .type_as(image)
+            .long()
+        )
+        output_tokens[:, 0] = self.start_index
+        for i in range(1, self.max_output_length):
+            trg = output_tokens[:, :i]
+            output = self.decode(image_features, trg)
+            output = torch.argmax(output, dim=-1)
+            output_tokens[:, i] = output[-1:]
+
+        # Set all tokens after end token to be padding.
+        for i in range(1, self.max_output_length):
+            indices = output_tokens[:, i - 1] == self.end_index | (
+                output_tokens[:, i - 1] == self.pad_index
+            )
+            output_tokens[indices, i] = self.pad_index
+
+        return output_tokens
diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py
deleted file mode 100644
index a6aaca4..0000000
--- a/text_recognizer/networks/image_transformer.py
+++ /dev/null
@@ -1,165 +0,0 @@
-"""A Transformer with a cnn backbone.
-
-The network encodes a image with a convolutional backbone to a latent representation,
-i.e. feature maps. A 2d positional encoding is applied to the feature maps for
-spatial information. The resulting feature are then set to a transformer decoder
-together with the target tokens.
-
-TODO: Local attention for lower layer in attention.
-
-"""
-import importlib
-import math
-from typing import Dict, Optional, Union, Sequence, Type
-
-from einops import rearrange
-from omegaconf import DictConfig, OmegaConf
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.data.emnist import NUM_SPECIAL_TOKENS
-from text_recognizer.networks.transformer import (
-    Decoder,
-    DecoderLayer,
-    PositionalEncoding,
-    PositionalEncoding2D,
-    target_padding_mask,
-)
-
-NUM_WORD_PIECES = 1000
-
-
-class ImageTransformer(nn.Module):
-    def __init__(
-        self,
-        input_shape: Sequence[int],
-        output_shape: Sequence[int],
-        encoder: Union[DictConfig, Dict],
-        vocab_size: Optional[int] = None,
-        num_decoder_layers: int = 4,
-        hidden_dim: int = 256,
-        num_heads: int = 4,
-        expansion_dim: int = 1024,
-        dropout_rate: float = 0.1,
-        transformer_activation: str = "glu",
-    ) -> None:
-        self.vocab_size = (
-            NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size
-        )
-        self.hidden_dim = hidden_dim
-        self.max_output_length = output_shape[0]
-
-        # Image backbone
-        self.encoder = self._configure_encoder(encoder)
-        self.feature_map_encoding = PositionalEncoding2D(
-            hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2]
-        )
-
-        # Target token embedding
-        self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim)
-        self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
-
-        # Transformer decoder
-        self.decoder = Decoder(
-            decoder_layer=DecoderLayer(
-                hidden_dim=hidden_dim,
-                num_heads=num_heads,
-                expansion_dim=expansion_dim,
-                dropout_rate=dropout_rate,
-                activation=transformer_activation,
-            ),
-            num_layers=num_decoder_layers,
-            norm=nn.LayerNorm(hidden_dim),
-        )
-
-        # Classification head
-        self.head = nn.Linear(hidden_dim, self.vocab_size)
-
-        # Initialize weights
-        self._init_weights()
-
-    def _init_weights(self) -> None:
-        """Initialize network weights."""
-        self.trg_embedding.weight.data.uniform_(-0.1, 0.1)
-        self.head.bias.data.zero_()
-        self.head.weight.data.uniform_(-0.1, 0.1)
-
-        nn.init.kaiming_normal_(
-            self.feature_map_encoding.weight.data,
-            a=0,
-            mode="fan_out",
-            nonlinearity="relu",
-        )
-        if self.feature_map_encoding.bias is not None:
-            _, fan_out = nn.init._calculate_fan_in_and_fan_out(
-                self.feature_map_encoding.weight.data
-            )
-            bound = 1 / math.sqrt(fan_out)
-            nn.init.normal_(self.feature_map_encoding.bias, -bound, bound)
-
-    @staticmethod
-    def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]:
-        encoder = OmegaConf.create(encoder)
-        network_module = importlib.import_module("text_recognizer.networks")
-        encoder_class = getattr(network_module, encoder.type)
-        return encoder_class(**encoder.args)
-
-    def encode(self, image: Tensor) -> Tensor:
-        """Extracts image features with backbone.
-
-        Args:
-            image (Tensor): Image(s) of handwritten text.
-
-        Retuns:
-            Tensor: Image features.
-
-        Shapes:
-            - image: :math: `(B, C, H, W)`
-            - latent: :math: `(B, T, C)`
-
-        """
-        # Extract image features.
-        image_features = self.encoder(image)
-
-        # Add 2d encoding to the feature maps.
-        image_features = self.feature_map_encoding(image_features)
-
-        # Collapse features maps height and width.
-        image_features = rearrange(image_features, "b c h w -> b (h w) c")
-        return image_features
-
-    def decode(self, memory: Tensor, trg: Tensor) -> Tensor:
-        """Decodes image features with transformer decoder."""
-        trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index)
-        trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim)
-        trg = self.trg_position_encoding(trg)
-        out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None)
-        logits = self.head(out)
-        return logits
-
-    def predict(self, image: Tensor) -> Tensor:
-        """Transcribes text in image(s)."""
-        bsz = image.shape[0]
-        image_features = self.encode(image)
-
-        output_tokens = (
-            (torch.ones((bsz, self.max_output_length)) * self.pad_index)
-            .type_as(image)
-            .long()
-        )
-        output_tokens[:, 0] = self.start_index
-        for i in range(1, self.max_output_length):
-            trg = output_tokens[:, :i]
-            output = self.decode(image_features, trg)
-            output = torch.argmax(output, dim=-1)
-            output_tokens[:, i] = output[-1:]
-
-        # Set all tokens after end token to be padding.
-        for i in range(1, self.max_output_length):
-            indices = output_tokens[:, i - 1] == self.end_index | (
-                output_tokens[:, i - 1] == self.pad_index
-            )
-            output_tokens[indices, i] = self.pad_index
-
-        return output_tokens
diff --git a/text_recognizer/networks/residual_network.py b/text_recognizer/networks/residual_network.py
index c33f419..da7553d 100644
--- a/text_recognizer/networks/residual_network.py
+++ b/text_recognizer/networks/residual_network.py
@@ -20,7 +20,11 @@ class Conv2dAuto(nn.Conv2d):
 
 def conv_bn(in_channels: int, out_channels: int, *args, **kwargs) -> nn.Sequential:
     """3x3 convolution with batch norm."""
-    conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,)
+    conv3x3 = partial(
+        Conv2dAuto,
+        kernel_size=3,
+        bias=False,
+    )
     return nn.Sequential(
         conv3x3(in_channels, out_channels, *args, **kwargs),
         nn.BatchNorm2d(out_channels),
diff --git a/text_recognizer/networks/transducer/transducer.py b/text_recognizer/networks/transducer/transducer.py
index d7e3d08..b10f93a 100644
--- a/text_recognizer/networks/transducer/transducer.py
+++ b/text_recognizer/networks/transducer/transducer.py
@@ -392,7 +392,12 @@ def load_transducer_loss(
         transitions = gtn.load(str(processed_path / transitions))
 
     preprocessor = Preprocessor(
-        data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep,
+        data_dir,
+        num_features,
+        tokens_path,
+        lexicon_path,
+        use_words,
+        prepend_wordsep,
     )
 
     num_tokens = preprocessor.num_tokens
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
index 8847aba..93a1e43 100644
--- a/text_recognizer/networks/vqvae/decoder.py
+++ b/text_recognizer/networks/vqvae/decoder.py
@@ -44,7 +44,12 @@ class Decoder(nn.Module):
 
         # Configure encoder.
         self.decoder = self._build_decoder(
-            channels, kernel_sizes, strides, num_residual_layers, activation, dropout,
+            channels,
+            kernel_sizes,
+            strides,
+            num_residual_layers,
+            activation,
+            dropout,
         )
 
     def _build_decompression_block(
@@ -72,8 +77,10 @@ class Decoder(nn.Module):
                 )
             )
-            if i < len(self.upsampling):
-                modules.append(nn.Upsample(size=self.upsampling[i]),)
+            if self.upsampling and i < len(self.upsampling):
+                modules.append(
+                    nn.Upsample(size=self.upsampling[i]),
+                )
 
             if dropout is not None:
                 modules.append(dropout)
@@ -102,7 +109,12 @@ class Decoder(nn.Module):
     ) -> nn.Sequential:
 
         self.res_block.append(
-            nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,)
+            nn.Conv2d(
+                self.embedding_dim,
+                channels[0],
+                kernel_size=1,
+                stride=1,
+            )
         )
 
         # Bottleneck module.
diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py
index d3adac5..b0cceed 100644
--- a/text_recognizer/networks/vqvae/encoder.py
+++ b/text_recognizer/networks/vqvae/encoder.py
@@ -1,5 +1,5 @@
 """CNN encoder for the VQ-VAE."""
-from typing import List, Optional, Tuple, Type
+from typing import Sequence, Optional, Tuple, Type
 
 import torch
 from torch import nn
@@ -11,7 +11,10 @@ from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer
 
 class _ResidualBlock(nn.Module):
     def __init__(
-        self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]],
+        self,
+        in_channels: int,
+        out_channels: int,
+        dropout: Optional[Type[nn.Module]],
     ) -> None:
         super().__init__()
         self.block = [
@@ -36,9 +39,9 @@ class Encoder(nn.Module):
     def __init__(
         self,
         in_channels: int,
-        channels: List[int],
-        kernel_sizes: List[int],
-        strides: List[int],
+        channels: Sequence[int],
+        kernel_sizes: Sequence[int],
+        strides: Sequence[int],
         num_residual_layers: int,
         embedding_dim: int,
         num_embeddings: int,
@@ -77,12 +80,12 @@ class Encoder(nn.Module):
             self.num_embeddings, self.embedding_dim, self.beta
         )
 
+    @staticmethod
     def _build_compression_block(
-        self,
         in_channels: int,
         channels: int,
-        kernel_sizes: List[int],
-        strides: List[int],
+        kernel_sizes: Sequence[int],
+        strides: Sequence[int],
         activation: Type[nn.Module],
         dropout: Optional[Type[nn.Module]],
     ) -> nn.ModuleList:
@@ -109,8 +112,8 @@ class Encoder(nn.Module):
         self,
         in_channels: int,
         channels: int,
-        kernel_sizes: List[int],
-        strides: List[int],
+        kernel_sizes: Sequence[int],
+        strides: Sequence[int],
         num_residual_layers: int,
         activation: Type[nn.Module],
         dropout: Optional[Type[nn.Module]],
@@ -135,7 +138,12 @@ class Encoder(nn.Module):
         )
 
         encoder.append(
-            nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,)
+            nn.Conv2d(
+                channels[-1],
+                self.embedding_dim,
+                kernel_size=1,
+                stride=1,
+            )
         )
 
         return nn.Sequential(*encoder)
diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py
index 50448b4..1f08e5e 100644
--- a/text_recognizer/networks/vqvae/vqvae.py
+++ b/text_recognizer/networks/vqvae/vqvae.py
@@ -1,8 +1,7 @@
 """The VQ-VAE."""
-from typing import List, Optional, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple
 
-import torch
 from torch import nn
 from torch import Tensor
@@ -25,6 +24,8 @@ class VQVAE(nn.Module):
         beta: float = 0.25,
         activation: str = "leaky_relu",
         dropout_rate: float = 0.0,
+        *args: Any,
+        **kwargs: Dict,
     ) -> None:
         super().__init__()
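Usage note: a minimal, hypothetical sketch of the word-piece target pipeline wired up by this change, not part of the diff itself. It assumes the IAM data has been downloaded and processed, and that the word-piece token and lexicon files referenced by the new defaults (`iamdb_1kwp_tokens_1000.txt`, `iamdb_1kwp_lex_1000.txt`) exist under `data/downloaded/iam/iamdb`; the `setup()`/`train_dataloader()` calls are the standard PyTorch Lightning data-module interface inherited from `BaseDataModule`.

```python
# Illustrative sketch only: enable the new word-piece targets on the
# extended paragraphs data module added to text_recognizer.data.__init__.
from text_recognizer.data import IAMExtendedParagraphs

# word_pieces=True makes get_target_transform() attach the WordPiece transform,
# which maps EMNIST character targets to word-piece indices (truncated to 192).
datamodule = IAMExtendedParagraphs(batch_size=16, word_pieces=True)
datamodule.prepare_data()
datamodule.setup()

images, targets = next(iter(datamodule.train_dataloader()))
# targets now hold word-piece indices rather than raw EMNIST character indices.
print(images.shape, targets.shape)
```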