diff options
| -rw-r--r-- | text_recognizer/networks/quantizer/__init__.py | 0 | ||||
| -rw-r--r-- | text_recognizer/networks/quantizer/codebook.py | 96 | ||||
| -rw-r--r-- | text_recognizer/networks/quantizer/kmeans.py | 32 | ||||
| -rw-r--r-- | text_recognizer/networks/quantizer/quantizer.py | 59 | ||||
| -rw-r--r-- | text_recognizer/networks/quantizer/utils.py | 26 | ||||
| -rw-r--r-- | text_recognizer/networks/vqvae/quantizer.py | 141 | ||||
| -rw-r--r-- | training/conf/experiment/vq_transformer_lines.yaml | 149 | ||||
| -rw-r--r-- | training/conf/network/quantizer.yaml | 12 | 
8 files changed, 374 insertions, 141 deletions
diff --git a/text_recognizer/networks/quantizer/__init__.py b/text_recognizer/networks/quantizer/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/text_recognizer/networks/quantizer/__init__.py diff --git a/text_recognizer/networks/quantizer/codebook.py b/text_recognizer/networks/quantizer/codebook.py new file mode 100644 index 0000000..cb9bc59 --- /dev/null +++ b/text_recognizer/networks/quantizer/codebook.py @@ -0,0 +1,96 @@ +"""Codebook module.""" +from typing import Tuple + +import attr +from einops import rearrange +import torch +from torch import nn, Tensor +import torch.nn.functional as F + +from text_recognizer.networks.quantizer.kmeans import kmeans +from text_recognizer.networks.quantizer.utils import ( +    ema_inplace, +    norm, +    sample_vectors, +) + + +@attr.s(eq=False) +class CosineSimilarityCodebook(nn.Module): +    """Cosine similarity codebook.""" + +    dim: int = attr.ib() +    codebook_size: int = attr.ib() +    kmeans_init: bool = attr.ib(default=False) +    kmeans_iters: int = attr.ib(default=10) +    decay: float = attr.ib(default=0.8) +    eps: float = attr.ib(default=1.0e-5) +    threshold_dead: int = attr.ib(default=2) + +    def __attrs_pre_init__(self) -> None: +        super().__init__() + +    def __attrs_post_init__(self) -> None: +        if not self.kmeans_init: +            embeddings = norm(torch.randn(self.codebook_size, self.dim)) +        else: +            embeddings = torch.zeros(self.codebook_size, self.dim) +        self.register_buffer("initalized", Tensor([not self.kmeans_init])) +        self.register_buffer("cluster_size", torch.zeros(self.codebook_size)) +        self.register_buffer("embeddings", embeddings) + +    def _initalize_embedding(self, data: Tensor) -> None: +        embeddings, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) +        self.embeddings.data.copy_(embeddings) +        self.cluster_size.data.copy_(cluster_size) +        self.initalized.data.copy_(Tensor([True])) + +    def _replace(self, samples: Tensor, mask: Tensor) -> None: +        samples = norm(samples) +        modified_codebook = torch.where( +            mask[..., None], +            sample_vectors(samples, self.codebook_size), +            self.embeddings, +        ) +        self.embeddings.data.copy_(modified_codebook) + +    def _replace_dead_codes(self, batch_samples: Tensor) -> None: +        if self.threshold_dead == 0: +            return +        dead_codes = self.cluster_size < self.threshold_dead +        if not torch.any(dead_codes): +            return +        batch_samples = rearrange(batch_samples, "... d -> (...) d") +        self._replace(batch_samples, mask=dead_codes) + +    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: +        """Quantizes tensor.""" +        shape = x.shape +        flatten = rearrange(x, "... d -> (...) d") +        flatten = norm(flatten) + +        if not self.initalized: +            self._initalize_embedding(flatten) + +        embeddings = norm(self.embeddings) +        dist = flatten @ embeddings.t() +        indices = dist.max(dim=-1).indices +        one_hot = F.one_hot(indices, self.codebook_size).type_as(x) +        indices = indices.view(*shape[:-1]) + +        quantized = F.embedding(indices, self.embeddings) + +        if self.training: +            bins = one_hot.sum(0) +            ema_inplace(self.cluster_size, bins, self.decay) +            zero_mask = bins == 0 +            bins = bins.masked_fill(zero_mask, 1.0) + +            embed_sum = flatten.t() @ one_hot +            embed_norm = (embed_sum / bins.unsqueeze(0)).t() +            embed_norm = norm(embed_norm) +            embed_norm = torch.where(zero_mask[..., None], embeddings, embed_norm) +            ema_inplace(self.embeddings, embed_norm, self.decay) +            self._replace_dead_codes(x) + +        return quantized, indices diff --git a/text_recognizer/networks/quantizer/kmeans.py b/text_recognizer/networks/quantizer/kmeans.py new file mode 100644 index 0000000..a34c381 --- /dev/null +++ b/text_recognizer/networks/quantizer/kmeans.py @@ -0,0 +1,32 @@ +"""K-means clustering for embeddings.""" +from typing import Tuple + +from einops import repeat +import torch +from torch import Tensor + +from text_recognizer.networks.quantizer.utils import norm, sample_vectors + + +def kmeans( +    samples: Tensor, num_clusters: int, num_iters: int = 10 +) -> Tuple[Tensor, Tensor]: +    """Compute k-means clusters.""" +    D = samples.shape[-1] + +    means = sample_vectors(samples, num_clusters) + +    for _ in range(num_iters): +        dists = samples @ means.t() +        buckets = dists.max(dim=-1).indices +        bins = torch.bincount(buckets, minlength=num_clusters) +        zero_mask = bins == 0 +        bins_min_clamped = bins.masked_fill(zero_mask, 1) + +        new_means = buckets.new_zeros(num_clusters, D).type_as(samples) +        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=D), samples) +        new_means /= bins_min_clamped[..., None] +        new_means = norm(new_means) +        means = torch.where(zero_mask[..., None], means, new_means) + +    return means, bins diff --git a/text_recognizer/networks/quantizer/quantizer.py b/text_recognizer/networks/quantizer/quantizer.py new file mode 100644 index 0000000..3e8f0b2 --- /dev/null +++ b/text_recognizer/networks/quantizer/quantizer.py @@ -0,0 +1,59 @@ +"""Implementation of a Vector Quantized Variational AutoEncoder. + +Reference: +https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py +""" +from typing import Tuple, Type + +import attr +from einops import rearrange +import torch +from torch import nn +from torch import Tensor +import torch.nn.functional as F + + +@attr.s(eq=False) +class VectorQuantizer(nn.Module): +    """Vector quantizer.""" + +    input_dim: int = attr.ib() +    codebook: Type[nn.Module] = attr.ib() +    commitment: float = attr.ib(default=1.0) +    project_in: nn.Linear = attr.ib(default=None, init=False) +    project_out: nn.Linear = attr.ib(default=None, init=False) + +    def __attrs_pre_init__(self) -> None: +        super().__init__() + +    def __attrs_post_init__(self) -> None: +        require_projection = self.codebook.dim != self.input_dim +        self.project_in = ( +            nn.Linear(self.input_dim, self.codebook.dim) +            if require_projection +            else nn.Identity() +        ) +        self.project_out = ( +            nn.Linear(self.codebook.dim, self.input_dim) +            if require_projection +            else nn.Identity() +        ) + +    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: +        """Quantizes latent vectors.""" +        H, W = x.shape[-2:] +        x = rearrange(x, "b d h w -> b (h w) d") +        x = self.project_in(x) + +        quantized, indices = self.codebook(x) + +        if self.training: +            commitment_loss = F.mse_loss(quantized.detach(), x) * self.commitment +            quantized = x + (quantized - x).detach() +        else: +            commitment_loss = torch.tensor([0.0]).type_as(x) + +        quantized = self.project_out(quantized) +        quantized = rearrange(quantized, "b (h w) d -> b d h w", h=H, w=W) + +        return quantized, indices, commitment_loss diff --git a/text_recognizer/networks/quantizer/utils.py b/text_recognizer/networks/quantizer/utils.py new file mode 100644 index 0000000..0502d49 --- /dev/null +++ b/text_recognizer/networks/quantizer/utils.py @@ -0,0 +1,26 @@ +"""Helper functions for quantization.""" +from typing import Tuple + +import torch +from torch import Tensor +import torch.nn.functional as F + + +def sample_vectors(samples: Tensor, num: int) -> Tensor: +    """Subsamples a set of vectors.""" +    B, device = samples.shape[0], samples.device +    if B >= num: +        indices = torch.randperm(B, device=device)[:num] +    else: +        indices = torch.randint(0, B, (num,), device=device)[:num] +    return samples[indices] + + +def norm(t: Tensor) -> Tensor: +    """Applies L2-normalization.""" +    return F.normalize(t, p=2, dim=-1) + + +def ema_inplace(moving_avg: Tensor, new: Tensor, decay: float) -> None: +    """Applies exponential moving average.""" +    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) diff --git a/text_recognizer/networks/vqvae/quantizer.py b/text_recognizer/networks/vqvae/quantizer.py deleted file mode 100644 index bba9b60..0000000 --- a/text_recognizer/networks/vqvae/quantizer.py +++ /dev/null @@ -1,141 +0,0 @@ -"""Implementation of a Vector Quantized Variational AutoEncoder. - -Reference: -https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py -""" -from einops import rearrange -import torch -from torch import nn -from torch import Tensor -import torch.nn.functional as F - - -class EmbeddingEMA(nn.Module): -    """Embedding for Exponential Moving Average (EMA).""" - -    def __init__(self, num_embeddings: int, embedding_dim: int) -> None: -        super().__init__() -        weight = torch.zeros(num_embeddings, embedding_dim) -        nn.init.kaiming_uniform_(weight, nonlinearity="linear") -        self.register_buffer("weight", weight) -        self.register_buffer("cluster_size", torch.zeros(num_embeddings)) -        self.register_buffer("weight_avg", weight.clone()) - - -class VectorQuantizer(nn.Module): -    """The codebook that contains quantized vectors.""" - -    def __init__( -        self, num_embeddings: int, embedding_dim: int, decay: float = 0.99 -    ) -> None: -        super().__init__() -        self.num_embeddings = num_embeddings -        self.embedding_dim = embedding_dim -        self.decay = decay -        self.embedding = EmbeddingEMA(self.num_embeddings, self.embedding_dim) - -    def _discretization_bottleneck(self, latent: Tensor) -> Tensor: -        """Computes the code nearest to the latent representation. - -        First we compute the posterior categorical distribution, and then map -        the latent representation to the nearest element of the embedding. - -        Args: -            latent (Tensor): The latent representation. - -        Shape: -            - latent :math:`(B x H x W, D)` - -        Returns: -            Tensor: The quantized embedding vector. - -        """ -        # Store latent shape. -        b, h, w, d = latent.shape - -        # Flatten the latent representation to 2D. -        latent = rearrange(latent, "b h w d -> (b h w) d") - -        # Compute the L2 distance between the latents and the embeddings. -        l2_distance = ( -            torch.sum(latent ** 2, dim=1, keepdim=True) -            + torch.sum(self.embedding.weight ** 2, dim=1) -            - 2 * latent @ self.embedding.weight.t() -        )  # [BHW x K] - -        # Find the embedding k nearest to each latent. -        encoding_indices = torch.argmin(l2_distance, dim=1).unsqueeze(1)  # [BHW, 1] - -        # Convert to one-hot encodings, aka discrete bottleneck. -        one_hot_encoding = torch.zeros( -            encoding_indices.shape[0], self.num_embeddings, device=latent.device -        ) -        one_hot_encoding.scatter_(1, encoding_indices, 1)  # [BHW x K] - -        # Embedding quantization. -        quantized_latent = one_hot_encoding @ self.embedding.weight  # [BHW, D] -        quantized_latent = rearrange( -            quantized_latent, "(b h w) d -> b h w d", b=b, h=h, w=w -        ) -        if self.training: -            self._compute_ema(one_hot_encoding=one_hot_encoding, latent=latent) - -        return quantized_latent - -    def _compute_ema(self, one_hot_encoding: Tensor, latent: Tensor) -> None: -        """Computes the EMA update to the codebook.""" -        batch_cluster_size = one_hot_encoding.sum(axis=0) -        batch_embedding_avg = (latent.t() @ one_hot_encoding).t() -        self.embedding.cluster_size.data.mul_(self.decay).add_( -            batch_cluster_size, alpha=1 - self.decay -        ) -        self.embedding.weight_avg.data.mul_(self.decay).add_( -            batch_embedding_avg, alpha=1 - self.decay -        ) -        new_embedding = self.embedding.weight_avg / ( -            self.embedding.cluster_size + 1.0e-5 -        ).unsqueeze(1) -        self.embedding.weight.data.copy_(new_embedding) - -    def _commitment_loss(self, latent: Tensor, quantized_latent: Tensor) -> Tensor: -        """Vector Quantization loss. - -        The vector quantization algorithm allows us to create a codebook. The VQ -        algorithm works by moving the embedding vectors towards the encoder outputs. - -        The embedding loss moves the embedding vector towards the encoder outputs. The -        .detach() works as the stop gradient (sg) described in the paper. - -        Because the volume of the embedding space is dimensionless, it can arbitarily -        grow if the embeddings are not trained as fast as the encoder parameters. To -        mitigate this, a commitment loss is added in the second term which makes sure -        that the encoder commits to an embedding and that its output does not grow. - -        Args: -            latent (Tensor): The encoder output. -            quantized_latent (Tensor): The quantized latent. - -        Returns: -            Tensor: The combinded VQ loss. - -        """ -        loss = F.mse_loss(quantized_latent.detach(), latent) -        return loss - -    def forward(self, latent: Tensor) -> Tensor: -        """Forward pass that returns the quantized vector and the vq loss.""" -        # Rearrange latent representation s.t. the hidden dim is at the end. -        latent = rearrange(latent, "b d h w -> b h w d") - -        # Maps latent to the nearest code in the codebook. -        quantized_latent = self._discretization_bottleneck(latent) - -        loss = self._commitment_loss(latent, quantized_latent) - -        # Add residue to the quantized latent. -        quantized_latent = latent + (quantized_latent - latent).detach() - -        # Rearrange the quantized shape back to the original shape. -        quantized_latent = rearrange(quantized_latent, "b h w d -> b d h w") - -        return quantized_latent, loss diff --git a/training/conf/experiment/vq_transformer_lines.yaml b/training/conf/experiment/vq_transformer_lines.yaml new file mode 100644 index 0000000..bbe1178 --- /dev/null +++ b/training/conf/experiment/vq_transformer_lines.yaml @@ -0,0 +1,149 @@ +# @package _global_ + +defaults: +  - override /mapping: null +  - override /criterion: cross_entropy +  - override /callbacks: htr +  - override /datamodule: iam_lines +  - override /network: null +  - override /model: null +  - override /lr_schedulers: null +  - override /optimizers: null + +epochs: &epochs 512 +ignore_index: &ignore_index 3 +num_classes: &num_classes 57 +max_output_len: &max_output_len 89 +summary: [[1, 1, 56, 1024], [1, 89]] + +criterion: +  ignore_index: *ignore_index +     +mapping: &mapping +  mapping: +    _target_: text_recognizer.data.mappings.emnist.EmnistMapping + +callbacks: +  stochastic_weight_averaging: +    _target_: pytorch_lightning.callbacks.StochasticWeightAveraging +    swa_epoch_start: 0.75 +    swa_lrs: 1.0e-5 +    annealing_epochs: 10 +    annealing_strategy: cos +    device: null + +optimizers: +  madgrad: +    _target_: madgrad.MADGRAD +    lr: 3.0e-4 +    momentum: 0.9 +    weight_decay: 0 +    eps: 1.0e-6 +    parameters: network + +lr_schedulers: +  network: +    _target_: torch.optim.lr_scheduler.CosineAnnealingLR +    T_max: *epochs +    eta_min: 1.0e-5 +    last_epoch: -1 +    interval: epoch +    monitor: val/loss + +datamodule: +  batch_size: 16 +  num_workers: 12 +  train_fraction: 0.9 +  pin_memory: true +  << : *mapping + +rotary_embedding: &rotary_embedding +  rotary_embedding:  +    _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding +    dim: 64 + +attn: &attn +  dim: &hidden_dim 512 +  num_heads: 4 +  dim_head: 64 +  dropout_rate: &dropout_rate 0.4 + +network: +  _target_: text_recognizer.networks.vq_transformer.VqTransformer +  input_dims: [1, 56, 1024] +  hidden_dim: *hidden_dim +  num_classes: *num_classes +  pad_index: *ignore_index +  encoder: +    _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet +    arch: b1 +    stochastic_dropout_rate: 0.2 +    bn_momentum: 0.99 +    bn_eps: 1.0e-3 +  decoder: +    depth: 6 +    _target_: text_recognizer.networks.transformer.layers.Decoder +    self_attn: +      _target_: text_recognizer.networks.transformer.attention.Attention +      << : *attn +      causal: true +      << : *rotary_embedding +    cross_attn: +      _target_: text_recognizer.networks.transformer.attention.Attention +      << : *attn +      causal: false +    norm: +      _target_: text_recognizer.networks.transformer.norm.ScaleNorm +      normalized_shape: *hidden_dim +    ff:  +      _target_: text_recognizer.networks.transformer.mlp.FeedForward +      dim: *hidden_dim +      dim_out: null +      expansion_factor: 4 +      glu: true +      dropout_rate: *dropout_rate +    pre_norm: true +  pixel_pos_embedding: +    _target_: text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbedding +    dim: *hidden_dim +    shape: [1, 32] +  quantizer: +    _target_: text_recognizer.networks.quantizer.quantizer.VectorQuantizer +    input_dim: 512 +    codebook: +      _target_: text_recognizer.networks.quantizer.codebook.CosineSimilarityCodebook +      dim: 16 +      codebook_size: 4096 +      kmeans_init: true +      kmeans_iters: 10 +      decay: 0.8 +      eps: 1.0e-5 +      threshold_dead: 2 +    commitment: 1.0 + +model: +  _target_: text_recognizer.models.vq_transformer.VqTransformerLitModel +  << : *mapping +  max_output_len: *max_output_len +  start_token: <s> +  end_token: <e> +  pad_token: <p> + +trainer: +  _target_: pytorch_lightning.Trainer +  stochastic_weight_avg: true +  auto_scale_batch_size: binsearch +  auto_lr_find: false +  gradient_clip_val: 0.5 +  fast_dev_run: false +  gpus: 1 +  precision: 16 +  max_epochs: *epochs +  terminate_on_nan: true +  weights_summary: null +  limit_train_batches: 1.0  +  limit_val_batches: 1.0 +  limit_test_batches: 1.0 +  resume_from_checkpoint: null +  accumulate_grad_batches: 1 +  overfit_batches: 0 diff --git a/training/conf/network/quantizer.yaml b/training/conf/network/quantizer.yaml new file mode 100644 index 0000000..827a247 --- /dev/null +++ b/training/conf/network/quantizer.yaml @@ -0,0 +1,12 @@ +_target_: text_recognizer.networks.quantizer.quantizer.VectorQuantizer +input_dim: 192 +codebook: +  _target_: text_recognizer.networks.quantizer.codebook.CosineSimilarityCodebook +  dim: 16 +  codebook_size: 2048 +  kmeans_init: true +  kmeans_iters: 10 +  decay: 0.8 +  eps: 1.0e-5 +  threshold_dead: 2 +commitment: 1.0  |