From 34098ccbbbf6379c0bd29a987440b8479c743746 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Thu, 29 Jul 2021 23:59:52 +0200 Subject: Configs, refactor with attrs, fix attr bug in iam --- text_recognizer/criterions/label_smoothing.py | 42 +++++ text_recognizer/criterions/label_smoothing_loss.py | 42 ----- text_recognizer/data/base_dataset.py | 1 + text_recognizer/data/emnist.py | 2 +- text_recognizer/data/iam_extended_paragraphs.py | 23 +-- text_recognizer/data/iam_lines.py | 6 +- text_recognizer/data/iam_paragraphs.py | 7 +- text_recognizer/data/iam_synthetic_paragraphs.py | 12 +- text_recognizer/models/base.py | 31 ++-- text_recognizer/models/transformer.py | 26 +-- text_recognizer/networks/base.py | 18 ++ text_recognizer/networks/cnn_tranformer.py | 202 --------------------- text_recognizer/networks/conv_transformer.py | 201 ++++++++++++++++++++ 13 files changed, 302 insertions(+), 311 deletions(-) create mode 100644 text_recognizer/criterions/label_smoothing.py delete mode 100644 text_recognizer/criterions/label_smoothing_loss.py create mode 100644 text_recognizer/networks/base.py delete mode 100644 text_recognizer/networks/cnn_tranformer.py create mode 100644 text_recognizer/networks/conv_transformer.py (limited to 'text_recognizer') diff --git a/text_recognizer/criterions/label_smoothing.py b/text_recognizer/criterions/label_smoothing.py new file mode 100644 index 0000000..40a7609 --- /dev/null +++ b/text_recognizer/criterions/label_smoothing.py @@ -0,0 +1,42 @@ +"""Implementations of custom loss functions.""" +import torch +from torch import nn +from torch import Tensor +import torch.nn.functional as F + + +class LabelSmoothingLoss(nn.Module): + """Label smoothing cross entropy loss.""" + + def __init__( + self, label_smoothing: float, vocab_size: int, ignore_index: int = -100 + ) -> None: + assert 0.0 < label_smoothing <= 1.0 + self.ignore_index = ignore_index + super().__init__() + + smoothing_value = label_smoothing / (vocab_size - 2) + one_hot = torch.full((vocab_size,), smoothing_value) + one_hot[self.ignore_index] = 0 + self.register_buffer("one_hot", one_hot.unsqueeze(0)) + + self.confidence = 1.0 - label_smoothing + + def forward(self, output: Tensor, targets: Tensor) -> Tensor: + """Computes the loss. + + Args: + output (Tensor): Predictions from the network. + targets (Tensor): Ground truth. + + Shapes: + outpus: Batch size x num classes + targets: Batch size + + Returns: + Tensor: Label smoothing loss. + """ + model_prob = self.one_hot.repeat(targets.size(0), 1) + model_prob.scatter_(1, targets.unsqueeze(1), self.confidence) + model_prob.masked_fill_((targets == self.ignore_index).unsqueeze(1), 0) + return F.kl_div(output, model_prob, reduction="sum") diff --git a/text_recognizer/criterions/label_smoothing_loss.py b/text_recognizer/criterions/label_smoothing_loss.py deleted file mode 100644 index 40a7609..0000000 --- a/text_recognizer/criterions/label_smoothing_loss.py +++ /dev/null @@ -1,42 +0,0 @@ -"""Implementations of custom loss functions.""" -import torch -from torch import nn -from torch import Tensor -import torch.nn.functional as F - - -class LabelSmoothingLoss(nn.Module): - """Label smoothing cross entropy loss.""" - - def __init__( - self, label_smoothing: float, vocab_size: int, ignore_index: int = -100 - ) -> None: - assert 0.0 < label_smoothing <= 1.0 - self.ignore_index = ignore_index - super().__init__() - - smoothing_value = label_smoothing / (vocab_size - 2) - one_hot = torch.full((vocab_size,), smoothing_value) - one_hot[self.ignore_index] = 0 - self.register_buffer("one_hot", one_hot.unsqueeze(0)) - - self.confidence = 1.0 - label_smoothing - - def forward(self, output: Tensor, targets: Tensor) -> Tensor: - """Computes the loss. - - Args: - output (Tensor): Predictions from the network. - targets (Tensor): Ground truth. - - Shapes: - outpus: Batch size x num classes - targets: Batch size - - Returns: - Tensor: Label smoothing loss. - """ - model_prob = self.one_hot.repeat(targets.size(0), 1) - model_prob.scatter_(1, targets.unsqueeze(1), self.confidence) - model_prob.masked_fill_((targets == self.ignore_index).unsqueeze(1), 0) - return F.kl_div(output, model_prob, reduction="sum") diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py index 4318dfb..c26f1c9 100644 --- a/text_recognizer/data/base_dataset.py +++ b/text_recognizer/data/base_dataset.py @@ -29,6 +29,7 @@ class BaseDataset(Dataset): super().__init__() def __attrs_post_init__(self) -> None: + # TODO: refactor this if len(self.data) != len(self.targets): raise ValueError("Data and targets must be of equal length.") diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py index d51a42a..2d0ac29 100644 --- a/text_recognizer/data/emnist.py +++ b/text_recognizer/data/emnist.py @@ -46,7 +46,7 @@ class EMNIST(BaseDataModule): EMNIST ByClass: 814,255 characters. 62 unbalanced classes. """ - train_fraction: float = attr.ib() + train_fraction: float = attr.ib(default=0.8) transform: Callable = attr.ib(init=False, default=T.Compose([T.ToTensor()])) def __attrs_post_init__(self) -> None: diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py index 886e37e..58c7369 100644 --- a/text_recognizer/data/iam_extended_paragraphs.py +++ b/text_recognizer/data/iam_extended_paragraphs.py @@ -13,23 +13,24 @@ from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs @attr.s(auto_attribs=True) class IAMExtendedParagraphs(BaseDataModule): - train_fraction: float = attr.ib() + augment: bool = attr.ib(default=True) + train_fraction: float = attr.ib(default=0.8) word_pieces: bool = attr.ib(default=False) def __attrs_post_init__(self) -> None: self.iam_paragraphs = IAMParagraphs( - self.batch_size, - self.num_workers, - self.train_fraction, - self.augment, - self.word_pieces, + batch_size=self.batch_size, + num_workers=self.num_workers, + train_fraction=self.train_fraction, + augment=self.augment, + word_pieces=self.word_pieces, ) self.iam_synthetic_paragraphs = IAMSyntheticParagraphs( - self.batch_size, - self.num_workers, - self.train_fraction, - self.augment, - self.word_pieces, + batch_size=self.batch_size, + num_workers=self.num_workers, + train_fraction=self.train_fraction, + augment=self.augment, + word_pieces=self.word_pieces, ) self.dims = self.iam_paragraphs.dims diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py index e45e5c8..705cfa3 100644 --- a/text_recognizer/data/iam_lines.py +++ b/text_recognizer/data/iam_lines.py @@ -34,6 +34,7 @@ SEED = 4711 PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_lines" IMAGE_HEIGHT = 56 IMAGE_WIDTH = 1024 +MAX_LABEL_LENGTH = 89 @attr.s(auto_attribs=True) @@ -42,11 +43,12 @@ class IAMLines(BaseDataModule): augment: bool = attr.ib(default=True) fraction: float = attr.ib(default=0.8) + dims: Tuple[int, int, int] = attr.ib(init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH)) + output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1)) def __attrs_post_init__(self) -> None: + # TODO: refactor this self.mapping, self.inverse_mapping, _ = emnist_mapping() - self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH) - self.output_dims = (89, 1) def prepare_data(self) -> None: """Creates the IAM lines dataset if not existing.""" diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index bdfb490..9977978 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -41,6 +41,8 @@ class IAMParagraphs(BaseDataModule): augment: bool = attr.ib(default=True) train_fraction: float = attr.ib(default=0.8) word_pieces: bool = attr.ib(default=False) + dims: Tuple[int, int, int] = attr.ib(init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH)) + output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1)) def __attrs_post_init__(self) -> None: self.mapping, self.inverse_mapping, _ = emnist_mapping( @@ -49,11 +51,6 @@ class IAMParagraphs(BaseDataModule): if self.word_pieces: self.mapping = WordPieceMapping() - self.train_fraction = train_fraction - - self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH) - self.output_dims = (MAX_LABEL_LENGTH, 1) - def prepare_data(self) -> None: """Create data for training/testing.""" if PROCESSED_DATA_DIRNAME.exists(): diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py index 00fa2b6..a3697e7 100644 --- a/text_recognizer/data/iam_synthetic_paragraphs.py +++ b/text_recognizer/data/iam_synthetic_paragraphs.py @@ -2,6 +2,7 @@ import random from typing import Any, List, Sequence, Tuple +import attr from loguru import logger import numpy as np from PIL import Image @@ -33,19 +34,10 @@ PROCESSED_DATA_DIRNAME = ( ) +@attr.s(auto_attribs=True) class IAMSyntheticParagraphs(IAMParagraphs): """IAM Handwriting database of synthetic paragraphs.""" - def __init__( - self, - batch_size: int = 16, - num_workers: int = 0, - train_fraction: float = 0.8, - augment: bool = True, - word_pieces: bool = False, - ) -> None: - super().__init__(batch_size, num_workers, train_fraction, augment, word_pieces) - def prepare_data(self) -> None: """Prepare IAM lines to be used to generate paragraphs.""" if PROCESSED_DATA_DIRNAME.exists(): diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index f95df0f..3b83056 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -3,20 +3,25 @@ from typing import Any, Dict, List, Tuple, Type import attr import hydra -import loguru.logger as log +from loguru import logger as log from omegaconf import DictConfig -import pytorch_lightning as LightningModule +from pytorch_lightning import LightningModule import torch from torch import nn from torch import Tensor import torchmetrics +from text_recognizer.networks.base import BaseNetwork + @attr.s class BaseLitModel(LightningModule): """Abstract PyTorch Lightning class.""" - network: Type[nn.Module] = attr.ib() + def __attrs_pre_init__(self) -> None: + super().__init__() + + network: Type[BaseNetwork] = attr.ib() criterion_config: DictConfig = attr.ib(converter=DictConfig) optimizer_config: DictConfig = attr.ib(converter=DictConfig) lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig) @@ -24,23 +29,13 @@ class BaseLitModel(LightningModule): interval: str = attr.ib() monitor: str = attr.ib(default="val/loss") - loss_fn = attr.ib(init=False) - - train_acc = attr.ib(init=False) - val_acc = attr.ib(init=False) - test_acc = attr.ib(init=False) - - def __attrs_pre_init__(self) -> None: - super().__init__() - - def __attrs_post_init__(self) -> None: - self.loss_fn = self._configure_criterion() + loss_fn: Type[nn.Module] = attr.ib(init=False) - # Accuracy metric - self.train_acc = torchmetrics.Accuracy() - self.val_acc = torchmetrics.Accuracy() - self.test_acc = torchmetrics.Accuracy() + train_acc: torchmetrics.Accuracy = attr.ib(init=False, default=torchmetrics.Accuracy()) + val_acc: torchmetrics.Accuracy = attr.ib(init=False, default=torchmetrics.Accuracy()) + test_acc: torchmetrics.Accuracy = attr.ib(init=False, default=torchmetrics.Accuracy()) + @loss_fn.default def configure_criterion(self) -> Type[nn.Module]: """Returns a loss functions.""" log.info(f"Instantiating criterion <{self.criterion_config._target_}>") diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 8c9fe8a..f5cb491 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -1,13 +1,11 @@ """PyTorch Lightning model for base Transformers.""" -from typing import Dict, List, Optional, Union, Tuple, Type +from typing import Dict, List, Optional, Sequence, Union, Tuple, Type import attr import hydra from omegaconf import DictConfig from torch import nn, Tensor -from text_recognizer.data.emnist import emnist_mapping -from text_recognizer.data.mappings import AbstractMapping from text_recognizer.models.metrics import CharacterErrorRate from text_recognizer.models.base import BaseLitModel @@ -16,30 +14,18 @@ from text_recognizer.models.base import BaseLitModel class TransformerLitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" - mapping_config: DictConfig = attr.ib(converter=DictConfig) + ignore_tokens: Sequence[str] = attr.ib(default=("", "", "

",)) + val_cer: CharacterErrorRate = attr.ib(init=False) + test_cer: CharacterErrorRate = attr.ib(init=False) def __attrs_post_init__(self) -> None: - self.mapping, ignore_tokens = self._configure_mapping() - self.val_cer = CharacterErrorRate(ignore_tokens) - self.test_cer = CharacterErrorRate(ignore_tokens) + self.val_cer = CharacterErrorRate(self.ignore_tokens) + self.test_cer = CharacterErrorRate(self.ignore_tokens) def forward(self, data: Tensor) -> Tensor: """Forward pass with the transformer network.""" return self.network.predict(data) - @staticmethod - def _configure_mapping() -> Tuple[Type[AbstractMapping], List[int]]: - """Configure mapping.""" - # TODO: Fix me!!! - # Load config with hydra - mapping, inverse_mapping, _ = emnist_mapping(["\n"]) - start_index = inverse_mapping[""] - end_index = inverse_mapping[""] - pad_index = inverse_mapping["

"] - ignore_tokens = [start_index, end_index, pad_index] - # TODO: add case for sentence pieces - return mapping, ignore_tokens - def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: """Training step.""" data, targets = batch diff --git a/text_recognizer/networks/base.py b/text_recognizer/networks/base.py new file mode 100644 index 0000000..07b6a32 --- /dev/null +++ b/text_recognizer/networks/base.py @@ -0,0 +1,18 @@ +"""Base network with required methods.""" +from abc import abstractmethod + +import attr +from torch import nn, Tensor + + +@attr.s +class BaseNetwork(nn.Module): + """Base network.""" + + def __attrs_pre_init__(self) -> None: + super().__init__() + + @abstractmethod + def predict(self, x: Tensor) -> Tensor: + """Return token indices for predictions.""" + ... diff --git a/text_recognizer/networks/cnn_tranformer.py b/text_recognizer/networks/cnn_tranformer.py deleted file mode 100644 index ce7ec43..0000000 --- a/text_recognizer/networks/cnn_tranformer.py +++ /dev/null @@ -1,202 +0,0 @@ -"""Vision transformer for character recognition.""" -import math -from typing import Tuple, Type - -import attr -import torch -from torch import nn, Tensor - -from text_recognizer.data.mappings import AbstractMapping -from text_recognizer.networks.encoders.efficientnet import EfficientNet -from text_recognizer.networks.transformer.layers import Decoder -from text_recognizer.networks.transformer.positional_encodings import ( - PositionalEncoding, - PositionalEncoding2D, -) - - -@attr.s -class Reader(nn.Module): - def __attrs_pre_init__(self) -> None: - super().__init__() - - # Parameters and placeholders, - input_dims: Tuple[int, int, int] = attr.ib() - hidden_dim: int = attr.ib() - dropout_rate: float = attr.ib() - max_output_len: int = attr.ib() - num_classes: int = attr.ib() - padding_idx: int = attr.ib() - start_token: str = attr.ib() - start_index: int = attr.ib(init=False) - end_token: str = attr.ib() - end_index: int = attr.ib(init=False) - pad_token: str = attr.ib() - pad_index: int = attr.ib(init=False) - - # Modules. - encoder: EfficientNet = attr.ib() - decoder: Decoder = attr.ib() - latent_encoder: nn.Sequential = attr.ib(init=False) - token_embedding: nn.Embedding = attr.ib(init=False) - token_pos_encoder: PositionalEncoding = attr.ib(init=False) - head: nn.Linear = attr.ib(init=False) - mapping: Type[AbstractMapping] = attr.ib(init=False) - - def __attrs_post_init__(self) -> None: - """Post init configuration.""" - self.start_index = int(self.mapping.get_index(self.start_token)) - self.end_index = int(self.mapping.get_index(self.end_token)) - self.pad_index = int(self.mapping.get_index(self.pad_token)) - # Latent projector for down sampling number of filters and 2d - # positional encoding. - self.latent_encoder = nn.Sequential( - nn.Conv2d( - in_channels=self.encoder.out_channels, - out_channels=self.hidden_dim, - kernel_size=1, - ), - PositionalEncoding2D( - hidden_dim=self.hidden_dim, - max_h=self.input_dims[1], - max_w=self.input_dims[2], - ), - nn.Flatten(start_dim=2), - ) - - # Token embedding. - self.token_embedding = nn.Embedding( - num_embeddings=self.num_classes, embedding_dim=self.hidden_dim - ) - - # Positional encoding for decoder tokens. - self.token_pos_encoder = PositionalEncoding( - hidden_dim=self.hidden_dim, dropout_rate=self.dropout_rate - ) - # Head - self.head = nn.Linear( - in_features=self.hidden_dim, out_features=self.num_classes - ) - - # Initalize weights for encoder. - self.init_weights() - - def init_weights(self) -> None: - """Initalize weights for decoder network and head.""" - bound = 0.1 - self.token_embedding.weight.data.uniform_(-bound, bound) - self.head.bias.data.zero_() - self.head.weight.data.uniform_(-bound, bound) - # TODO: Initalize encoder? - - def encode(self, x: Tensor) -> Tensor: - """Encodes an image into a latent feature vector. - - Args: - x (Tensor): Image tensor. - - Shape: - - x: :math: `(B, C, H, W)` - - z: :math: `(B, Sx, E)` - - where Sx is the length of the flattened feature maps projected from - the encoder. E latent dimension for each pixel in the projected - feature maps. - - Returns: - Tensor: A Latent embedding of the image. - """ - z = self.encoder(x) - z = self.latent_encoder(z) - - # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] - z = z.permute(0, 2, 1) - return z - - def decode(self, z: Tensor, context: Tensor) -> Tensor: - """Decodes latent images embedding into word pieces. - - Args: - z (Tensor): Latent images embedding. - context (Tensor): Word embeddings. - - Shapes: - - z: :math: `(B, Sx, E)` - - context: :math: `(B, Sy)` - - out: :math: `(B, Sy, T)` - - where Sy is the length of the output and T is the number of tokens. - - Returns: - Tensor: Sequence of word piece embeddings. - """ - context_mask = context != self.padding_idx - context = self.token_embedding(context) * math.sqrt(self.hidden_dim) - context = self.token_pos_encoder(context) - out = self.decoder(x=context, context=z, mask=context_mask) - logits = self.head(out) - return logits - - def forward(self, x: Tensor, context: Tensor) -> Tensor: - """Encodes images into word piece logtis. - - Args: - x (Tensor): Input image(s). - context (Tensor): Target word embeddings. - - Shapes: - - x: :math: `(B, C, H, W)` - - context: :math: `(B, Sy, T)` - - where B is the batch size, C is the number of input channels, H is - the image height and W is the image width. - - Returns: - Tensor: Sequence of logits. - """ - z = self.encode(x) - logits = self.decode(z, context) - return logits - - def predict(self, x: Tensor) -> Tensor: - """Predicts text in image. - - Args: - x (Tensor): Image(s) to extract text from. - - Shapes: - - x: :math: `(B, H, W)` - - output: :math: `(B, S)` - - Returns: - Tensor: A tensor of token indices of the predictions from the model. - """ - bsz = x.shape[0] - - # Encode image(s) to latent vectors. - z = self.encode(x) - - # Create a placeholder matrix for storing outputs from the network - output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device) - output[:, 0] = self.start_index - - for i in range(1, self.max_output_len): - context = output[:, :i] # (bsz, i) - logits = self.decode(z, context) # (i, bsz, c) - tokens = torch.argmax(logits, dim=-1) # (i, bsz) - output[:, i : i + 1] = tokens[-1:] - - # Early stopping of prediction loop if token is end or padding token. - if ( - output[:, i - 1] == self.end_index | output[: i - 1] == self.pad_index - ).all(): - break - - # Set all tokens after end token to pad token. - for i in range(1, self.max_output_len): - idx = ( - output[:, i - 1] == self.end_index | output[:, i - 1] == self.pad_index - ) - output[idx, i] = self.pad_index - - return output diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py new file mode 100644 index 0000000..4acdc36 --- /dev/null +++ b/text_recognizer/networks/conv_transformer.py @@ -0,0 +1,201 @@ +"""Vision transformer for character recognition.""" +import math +from typing import Tuple, Type + +import attr +import torch +from torch import nn, Tensor + +from text_recognizer.data.mappings import AbstractMapping +from text_recognizer.networks.base import BaseNetwork +from text_recognizer.networks.encoders.efficientnet import EfficientNet +from text_recognizer.networks.transformer.layers import Decoder +from text_recognizer.networks.transformer.positional_encodings import ( + PositionalEncoding, + PositionalEncoding2D, +) + + +@attr.s(auto_attribs=True) +class ConvTransformer(BaseNetwork): + # Parameters and placeholders, + input_dims: Tuple[int, int, int] = attr.ib() + hidden_dim: int = attr.ib() + dropout_rate: float = attr.ib() + max_output_len: int = attr.ib() + num_classes: int = attr.ib() + start_token: str = attr.ib() + start_index: Tensor = attr.ib(init=False) + end_token: str = attr.ib() + end_index: Tensor = attr.ib(init=False) + pad_token: str = attr.ib() + pad_index: Tensor = attr.ib(init=False) + + # Modules. + encoder: EfficientNet = attr.ib() + decoder: Decoder = attr.ib() + mapping: Type[AbstractMapping] = attr.ib() + + latent_encoder: nn.Sequential = attr.ib(init=False) + token_embedding: nn.Embedding = attr.ib(init=False) + token_pos_encoder: PositionalEncoding = attr.ib(init=False) + head: nn.Linear = attr.ib(init=False) + + def __attrs_post_init__(self) -> None: + """Post init configuration.""" + self.start_index = self.mapping.get_index(self.start_token) + self.end_index = self.mapping.get_index(self.end_token) + self.pad_index = self.mapping.get_index(self.pad_token) + + # Latent projector for down sampling number of filters and 2d + # positional encoding. + self.latent_encoder = nn.Sequential( + nn.Conv2d( + in_channels=self.encoder.out_channels, + out_channels=self.hidden_dim, + kernel_size=1, + ), + PositionalEncoding2D( + hidden_dim=self.hidden_dim, + max_h=self.input_dims[1], + max_w=self.input_dims[2], + ), + nn.Flatten(start_dim=2), + ) + + # Token embedding. + self.token_embedding = nn.Embedding( + num_embeddings=self.num_classes, embedding_dim=self.hidden_dim + ) + + # Positional encoding for decoder tokens. + self.token_pos_encoder = PositionalEncoding( + hidden_dim=self.hidden_dim, dropout_rate=self.dropout_rate + ) + # Head + self.head = nn.Linear( + in_features=self.hidden_dim, out_features=self.num_classes + ) + + # Initalize weights for encoder. + self.init_weights() + + def init_weights(self) -> None: + """Initalize weights for decoder network and head.""" + bound = 0.1 + self.token_embedding.weight.data.uniform_(-bound, bound) + self.head.bias.data.zero_() + self.head.weight.data.uniform_(-bound, bound) + # TODO: Initalize encoder? + + def encode(self, x: Tensor) -> Tensor: + """Encodes an image into a latent feature vector. + + Args: + x (Tensor): Image tensor. + + Shape: + - x: :math: `(B, C, H, W)` + - z: :math: `(B, Sx, E)` + + where Sx is the length of the flattened feature maps projected from + the encoder. E latent dimension for each pixel in the projected + feature maps. + + Returns: + Tensor: A Latent embedding of the image. + """ + z = self.encoder(x) + z = self.latent_encoder(z) + + # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] + z = z.permute(0, 2, 1) + return z + + def decode(self, z: Tensor, context: Tensor) -> Tensor: + """Decodes latent images embedding into word pieces. + + Args: + z (Tensor): Latent images embedding. + context (Tensor): Word embeddings. + + Shapes: + - z: :math: `(B, Sx, E)` + - context: :math: `(B, Sy)` + - out: :math: `(B, Sy, T)` + + where Sy is the length of the output and T is the number of tokens. + + Returns: + Tensor: Sequence of word piece embeddings. + """ + context_mask = context != self.pad_index + context = self.token_embedding(context) * math.sqrt(self.hidden_dim) + context = self.token_pos_encoder(context) + out = self.decoder(x=context, context=z, mask=context_mask) + logits = self.head(out) + return logits + + def forward(self, x: Tensor, context: Tensor) -> Tensor: + """Encodes images into word piece logtis. + + Args: + x (Tensor): Input image(s). + context (Tensor): Target word embeddings. + + Shapes: + - x: :math: `(B, C, H, W)` + - context: :math: `(B, Sy, T)` + + where B is the batch size, C is the number of input channels, H is + the image height and W is the image width. + + Returns: + Tensor: Sequence of logits. + """ + z = self.encode(x) + logits = self.decode(z, context) + return logits + + def predict(self, x: Tensor) -> Tensor: + """Predicts text in image. + + Args: + x (Tensor): Image(s) to extract text from. + + Shapes: + - x: :math: `(B, H, W)` + - output: :math: `(B, S)` + + Returns: + Tensor: A tensor of token indices of the predictions from the model. + """ + bsz = x.shape[0] + + # Encode image(s) to latent vectors. + z = self.encode(x) + + # Create a placeholder matrix for storing outputs from the network + output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device) + output[:, 0] = self.start_index + + for i in range(1, self.max_output_len): + context = output[:, :i] # (bsz, i) + logits = self.decode(z, context) # (i, bsz, c) + tokens = torch.argmax(logits, dim=-1) # (i, bsz) + output[:, i : i + 1] = tokens[-1:] + + # Early stopping of prediction loop if token is end or padding token. + if ( + output[:, i - 1] == self.end_index | output[: i - 1] == self.pad_index + ).all(): + break + + # Set all tokens after end token to pad token. + for i in range(1, self.max_output_len): + idx = ( + output[:, i - 1] == self.end_index | output[:, i - 1] == self.pad_index + ) + output[idx, i] = self.pad_index + + return output -- cgit v1.2.3-70-g09d2