-rw-r--r-- | text_recognizer/models/conformer.py | 124
-rw-r--r-- | text_recognizer/networks/conformer/__init__.py | 7
-rw-r--r-- | text_recognizer/networks/conformer/attention.py | 49
-rw-r--r-- | text_recognizer/networks/conformer/block.py | 34
-rw-r--r-- | text_recognizer/networks/conformer/conformer.py | 35
-rw-r--r-- | text_recognizer/networks/conformer/conv.py | 40
-rw-r--r-- | text_recognizer/networks/conformer/ff.py | 19
-rw-r--r-- | text_recognizer/networks/conformer/glu.py | 12
-rw-r--r-- | text_recognizer/networks/conformer/scale.py | 13
-rw-r--r-- | text_recognizer/networks/conformer/subsampler.py | 50
-rw-r--r-- | training/conf/experiment/conformer_lines.yaml | 88
-rw-r--r-- | training/conf/network/conformer.yaml | 36
12 files changed, 0 insertions, 507 deletions
diff --git a/text_recognizer/models/conformer.py b/text_recognizer/models/conformer.py
deleted file mode 100644
index 41a9e4d..0000000
--- a/text_recognizer/models/conformer.py
+++ /dev/null
@@ -1,124 +0,0 @@
-"""Lightning Conformer model."""
-import itertools
-from typing import Optional, Tuple, Type
-
-from omegaconf import DictConfig
-import torch
-from torch import nn, Tensor
-
-from text_recognizer.data.mappings import EmnistMapping
-from text_recognizer.models.base import LitBase
-from text_recognizer.models.metrics import CharacterErrorRate
-from text_recognizer.models.util import first_element
-
-
-class LitConformer(LitBase):
-    """A PyTorch Lightning model for transformer networks."""
-
-    def __init__(
-        self,
-        network: Type[nn.Module],
-        loss_fn: Type[nn.Module],
-        optimizer_configs: DictConfig,
-        lr_scheduler_configs: Optional[DictConfig],
-        mapping: EmnistMapping,
-        max_output_len: int = 451,
-        start_token: str = "<s>",
-        end_token: str = "<e>",
-        pad_token: str = "<p>",
-        blank_token: str = "<b>",
-    ) -> None:
-        super().__init__(
-            network, loss_fn, optimizer_configs, lr_scheduler_configs, mapping
-        )
-        self.max_output_len = max_output_len
-        self.start_token = start_token
-        self.end_token = end_token
-        self.pad_token = pad_token
-        self.blank_token = blank_token
-        self.start_index = int(self.mapping.get_index(self.start_token))
-        self.end_index = int(self.mapping.get_index(self.end_token))
-        self.pad_index = int(self.mapping.get_index(self.pad_token))
-        self.blank_index = int(self.mapping.get_index(self.blank_token))
-        self.ignore_indices = set(
-            [self.start_index, self.end_index, self.pad_index, self.blank_index]
-        )
-        self.val_cer = CharacterErrorRate(self.ignore_indices)
-        self.test_cer = CharacterErrorRate(self.ignore_indices)
-
-    @torch.no_grad()
-    def predict(self, x: Tensor) -> str:
-        """Predicts a sequence of characters."""
-        logits = self(x)
-        logprobs = torch.log_softmax(logits, dim=1)
-        return self.decode(logprobs, self.max_output_len)
-
-    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
-        """Training step."""
-        data, targets = batch
-        logits = self(data)
-        logprobs = torch.log_softmax(logits, dim=1)
-        B, S, _ = logprobs.shape
-        input_length = torch.ones(B).type_as(logprobs).int() * S
-        target_length = first_element(targets, self.pad_index).type_as(targets)
-        loss = self.loss_fn(
-            logprobs.permute(1, 0, 2), targets, input_length, target_length
-        )
-        self.log("train/loss", loss)
-        return loss
-
-    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
-        """Validation step."""
-        data, targets = batch
-        logits = self(data)
-        logprobs = torch.log_softmax(logits, dim=1)
-        B, S, _ = logprobs.shape
-        input_length = torch.ones(B).type_as(logprobs).int() * S
-        target_length = first_element(targets, self.pad_index).type_as(targets)
-        loss = self.loss_fn(
-            logprobs.permute(1, 0, 2), targets, input_length, target_length
-        )
-        self.log("val/loss", loss)
-        preds = self.decode(logprobs, targets.shape[1])
-        self.val_acc(preds, targets)
-        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True)
-        self.val_cer(preds, targets)
-        self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
-
-    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
-        """Test step."""
-        data, targets = batch
-        logits = self(data)
-        logprobs = torch.log_softmax(logits, dim=1)
-        preds = self.decode(logprobs, targets.shape[1])
-        self.val_acc(preds, targets)
-        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True)
-        self.val_cer(preds, targets)
-        self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
-
-    def decode(self, logprobs: Tensor, max_length: int) -> Tensor:
-        """Greedly decodes a log prob sequence.
-
-        Args:
-            logprobs (Tensor): Log probabilities.
-            max_length (int): Max length of a sequence.
-
-        Shapes:
-            - x: :math: `(B, T, C)`
-            - output: :math: `(B, T)`
-
-        Returns:
-            Tensor: A predicted sequence of characters.
-        """
-        B = logprobs.shape[0]
-        argmax = logprobs.argmax(2)
-        decoded = torch.ones((B, max_length)).type_as(logprobs).int() * self.pad_index
-        for i in range(B):
-            seq = [
-                b
-                for b, _ in itertools.groupby(argmax[i].tolist())
-                if b != self.blank_index
-            ][:max_length]
-            for j, c in enumerate(seq):
-                decoded[i, j] = c
-        return decoded
diff --git a/text_recognizer/networks/conformer/__init__.py b/text_recognizer/networks/conformer/__init__.py
deleted file mode 100644
index 8951481..0000000
--- a/text_recognizer/networks/conformer/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from text_recognizer.networks.conformer.block import ConformerBlock
-from text_recognizer.networks.conformer.ff import Feedforward
-from text_recognizer.networks.conformer.glu import GLU
-from text_recognizer.networks.conformer.conformer import Conformer
-from text_recognizer.networks.conformer.conv import ConformerConv
-from text_recognizer.networks.conformer.subsampler import Subsampler
-from text_recognizer.networks.conformer.attention import Attention
diff --git a/text_recognizer/networks/conformer/attention.py b/text_recognizer/networks/conformer/attention.py
deleted file mode 100644
index e56e572..0000000
--- a/text_recognizer/networks/conformer/attention.py
+++ /dev/null
@@ -1,49 +0,0 @@
-"""Efficient self attention."""
-from einops import rearrange
-import torch
-import torch.nn.functional as F
-from torch import einsum, nn, Tensor
-
-
-class LayerNorm(nn.Module):
-    def __init__(self, dim: int) -> None:
-        super().__init__()
-        self.gamma = nn.Parameter(torch.ones(dim))
-        self.register_buffer("beta", torch.zeros(dim))
-
-    def forward(self, x: Tensor) -> Tensor:
-        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
-
-
-class SwiGLU(nn.Module):
-    def forward(self, x: Tensor) -> Tensor:
-        x, gate = x.chunk(2, dim=-1)
-        return F.silu(gate) * x
-
-
-class Attention(nn.Module):
-    def __init__(
-        self, dim: int, dim_head: int = 64, heads: int = 8, mult: int = 4
-    ) -> None:
-        super().__init__()
-        self.norm = LayerNorm(dim)
-        attn_inner_dim = heads * dim_head
-        ff_inner_dim = mult * dim
-        self.heads = heads
-        self.scale = dim_head ** -0.5
-        self.fused_dims = (attn_inner_dim, dim_head, dim_head, (2 * ff_inner_dim))
-        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
-        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)
-        self.ff_out = nn.Sequential(SwiGLU(), nn.Linear(ff_inner_dim, dim, bias=False))
-
-    def forward(self, x: Tensor) -> Tensor:
-        h = self.heads
-        x = self.norm(x)
-        q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)
-        q = rearrange(q, "b n (h d) -> b h n d", h=h)
-        q = q * self.scale
-        sim = einsum("b h i d, b j d -> b h i j", q, k)
-        attn = sim.softmax(dim=-1)
-        out = einsum("b h i j, b j d -> b h i d", attn, v)
-        out = rearrange(out, "b h n d -> b n (h d)")
-        return self.attn_out(out) + self.ff_out(ff)
diff --git a/text_recognizer/networks/conformer/block.py b/text_recognizer/networks/conformer/block.py
deleted file mode 100644
index c53f339..0000000
--- a/text_recognizer/networks/conformer/block.py
+++ /dev/null
@@ -1,34 +0,0 @@
-"""Conformer block."""
-from copy import deepcopy
-from typing import Optional
-
-from torch import nn, Tensor
-from text_recognizer.networks.conformer.conv import ConformerConv
-
-from text_recognizer.networks.conformer.ff import Feedforward
-from text_recognizer.networks.conformer.scale import Scale
-from text_recognizer.networks.transformer.attention import Attention
-from text_recognizer.networks.transformer.norm import PreNorm
-
-
-class ConformerBlock(nn.Module):
-    def __init__(
-        self,
-        dim: int,
-        ff: Feedforward,
-        attn: Attention,
-        conv: ConformerConv,
-    ) -> None:
-        super().__init__()
-        self.attn = PreNorm(dim, attn)
-        self.ff_1 = Scale(0.5, ff)
-        self.ff_2 = deepcopy(self.ff_1)
-        self.conv = conv
-        self.post_norm = nn.LayerNorm(dim)
-
-    def forward(self, x: Tensor) -> Tensor:
-        x = self.ff_1(x) + x
-        x = self.attn(x) + x
-        x = self.conv(x) + x
-        x = self.ff_2(x) + x
-        return self.post_norm(x)
diff --git a/text_recognizer/networks/conformer/conformer.py b/text_recognizer/networks/conformer/conformer.py
deleted file mode 100644
index 09aad55..0000000
--- a/text_recognizer/networks/conformer/conformer.py
+++ /dev/null
@@ -1,35 +0,0 @@
-"""Conformer module."""
-from copy import deepcopy
-from typing import Type
-
-from torch import nn, Tensor
-
-from text_recognizer.networks.conformer.block import ConformerBlock
-
-
-class Conformer(nn.Module):
-    def __init__(
-        self,
-        dim: int,
-        dim_gru: int,
-        num_classes: int,
-        subsampler: Type[nn.Module],
-        block: ConformerBlock,
-        depth: int,
-    ) -> None:
-        super().__init__()
-        self.subsampler = subsampler
-        self.blocks = nn.ModuleList([deepcopy(block) for _ in range(depth)])
-        self.gru = nn.GRU(
-            dim, dim_gru, 1, bidirectional=True, batch_first=True, bias=False
-        )
-        self.fc = nn.Linear(dim_gru, num_classes)
-
-    def forward(self, x: Tensor) -> Tensor:
-        x = self.subsampler(x)
-        B, T, C = x.shape
-        for fn in self.blocks:
-            x = fn(x)
-        x, _ = self.gru(x)
-        x = x.view(B, T, 2, -1).sum(2)
-        return self.fc(x)
diff --git a/text_recognizer/networks/conformer/conv.py b/text_recognizer/networks/conformer/conv.py
deleted file mode 100644
index ac13f5d..0000000
--- a/text_recognizer/networks/conformer/conv.py
+++ /dev/null
@@ -1,40 +0,0 @@
-"""Conformer convolutional block."""
-from einops import rearrange
-from einops.layers.torch import Rearrange
-from torch import nn, Tensor
-
-
-from text_recognizer.networks.conformer.glu import GLU
-
-
-class ConformerConv(nn.Module):
-    def __init__(
-        self,
-        dim: int,
-        expansion_factor: int = 2,
-        kernel_size: int = 31,
-        dropout: int = 0.0,
-    ) -> None:
-        super().__init__()
-        inner_dim = expansion_factor * dim
-        self.layers = nn.Sequential(
-            nn.LayerNorm(dim),
-            Rearrange("b n c -> b c n"),
-            nn.Conv1d(dim, 2 * inner_dim, 1),
-            GLU(dim=1),
-            nn.Conv1d(
-                in_channels=inner_dim,
-                out_channels=inner_dim,
-                kernel_size=kernel_size,
-                groups=inner_dim,
-                padding="same",
-            ),
-            nn.BatchNorm1d(inner_dim),
-            nn.Mish(inplace=True),
-            nn.Conv1d(inner_dim, dim, 1),
-            Rearrange("b c n -> b n c"),
-            nn.Dropout(dropout),
-        )
-
-    def forward(self, x: Tensor) -> Tensor:
-        return self.layers(x)
diff --git a/text_recognizer/networks/conformer/ff.py b/text_recognizer/networks/conformer/ff.py
deleted file mode 100644
index 2ef4245..0000000
--- a/text_recognizer/networks/conformer/ff.py
+++ /dev/null
@@ -1,19 +0,0 @@
-"""Conformer feedforward block."""
-from torch import nn, Tensor
-
-
-class Feedforward(nn.Module):
-    def __init__(
-        self, dim: int, expansion_factor: int = 4, dropout: float = 0.0
-    ) -> None:
-        super().__init__()
-        self.layers = nn.Sequential(
-            nn.Linear(dim, expansion_factor * dim),
-            nn.Mish(inplace=True),
-            nn.Dropout(dropout),
-            nn.Linear(expansion_factor * dim, dim),
-            nn.Dropout(dropout),
-        )
-
-    def forward(self, x: Tensor) -> Tensor:
-        return self.layers(x)
diff --git a/text_recognizer/networks/conformer/glu.py b/text_recognizer/networks/conformer/glu.py
deleted file mode 100644
index 1a7c201..0000000
--- a/text_recognizer/networks/conformer/glu.py
+++ /dev/null
@@ -1,12 +0,0 @@
-"""GLU layer."""
-from torch import nn, Tensor
-
-
-class GLU(nn.Module):
-    def __init__(self, dim: int) -> None:
-        super().__init__()
-        self.dim = dim
-
-    def forward(self, x: Tensor) -> Tensor:
-        out, gate = x.chunk(2, dim=self.dim)
-        return out * gate.sigmoid()
diff --git a/text_recognizer/networks/conformer/scale.py b/text_recognizer/networks/conformer/scale.py
deleted file mode 100644
index d012b81..0000000
--- a/text_recognizer/networks/conformer/scale.py
+++ /dev/null
@@ -1,13 +0,0 @@
-"""Scale layer."""
-from typing import Dict
-from torch import nn, Tensor
-
-
-class Scale(nn.Module):
-    def __init__(self, scale: float, fn: nn.Module) -> None:
-        super().__init__()
-        self.scale = scale
-        self.fn = fn
-
-    def forward(self, x: Tensor, **kwargs) -> Tensor:
-        return self.fn(x, **kwargs) * self.scale
diff --git a/text_recognizer/networks/conformer/subsampler.py b/text_recognizer/networks/conformer/subsampler.py
deleted file mode 100644
index 133b53a..0000000
--- a/text_recognizer/networks/conformer/subsampler.py
+++ /dev/null
@@ -1,50 +0,0 @@
-"""Simple convolutional network."""
-from typing import Tuple
-
-from einops import rearrange
-from torch import nn, Tensor
-
-from text_recognizer.networks.transformer import AxialPositionalEmbedding
-
-
-class Subsampler(nn.Module):
-    def __init__(
-        self,
-        channels: int,
-        dim: int,
-        depth: int,
-        height: int,
-        pixel_pos_embedding: AxialPositionalEmbedding,
-        dropout: float = 0.1,
-    ) -> None:
-        super().__init__()
-        self.pixel_pos_embedding = pixel_pos_embedding
-        self.subsampler, self.projector = self._build(
-            channels, height, dim, depth, dropout
-        )
-
-    def _build(
-        self, channels: int, height: int, dim: int, depth: int, dropout: float
-    ) -> Tuple[nn.Sequential, nn.Sequential]:
-        subsampler = []
-        for i in range(depth):
-            subsampler.append(
-                nn.Conv2d(
-                    in_channels=1 if i == 0 else channels,
-                    out_channels=channels,
-                    kernel_size=3,
-                    stride=2,
-                )
-            )
-            subsampler.append(nn.Mish(inplace=True))
-        projector = nn.Sequential(
-            nn.Linear(channels * height, dim), nn.Dropout(dropout)
-        )
-        return nn.Sequential(*subsampler), projector
-
-    def forward(self, x: Tensor) -> Tensor:
-        x = self.subsampler(x)
-        x = self.pixel_pos_embedding(x)
-        x = rearrange(x, "b c h w -> b w (c h)")
-        x = self.projector(x)
-        return x
diff --git a/training/conf/experiment/conformer_lines.yaml b/training/conf/experiment/conformer_lines.yaml
deleted file mode 100644
index 06e761e..0000000
--- a/training/conf/experiment/conformer_lines.yaml
+++ /dev/null
@@ -1,88 +0,0 @@
-# @package _global_
-
-defaults:
-  - override /mapping: null
-  - override /criterion: ctc
-  - override /callbacks: htr
-  - override /datamodule: iam_lines
-  - override /network: conformer
-  - override /model: null
-  - override /lr_schedulers: null
-  - override /optimizers: null
-
-epochs: &epochs 999
-num_classes: &num_classes 57
-max_output_len: &max_output_len 89
-summary: [[1, 56, 1024]]
-
-mapping: &mapping
-  mapping:
-    _target_: text_recognizer.data.mappings.EmnistMapping
-
-callbacks:
-  stochastic_weight_averaging:
-    _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
-    swa_epoch_start: 0.75
-    swa_lrs: 1.0e-5
-    annealing_epochs: 10
-    annealing_strategy: cos
-    device: null
-
-optimizers:
-  radam:
-    _target_: torch.optim.RAdam
-    lr: 3.0e-4
-    betas: [0.9, 0.999]
-    weight_decay: 0
-    eps: 1.0e-8
-    parameters: network
-
-lr_schedulers:
-  network:
-    _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
-    mode: min
-    factor: 0.5
-    patience: 10
-    threshold: 1.0e-4
-    threshold_mode: rel
-    cooldown: 0
-    min_lr: 1.0e-5
-    eps: 1.0e-8
-    verbose: false
-    interval: epoch
-    monitor: val/loss
-
-datamodule:
-  batch_size: 8
-  num_workers: 12
-  train_fraction: 0.9
-  pin_memory: true
-  << : *mapping
-
-model:
-  _target_: text_recognizer.models.conformer.LitConformer
-  <<: *mapping
-  max_output_len: *max_output_len
-  start_token: <s>
-  end_token: <e>
-  pad_token: <p>
-  blank_token: <b>
-
-trainer:
-  _target_: pytorch_lightning.Trainer
-  stochastic_weight_avg: true
-  auto_scale_batch_size: binsearch
-  auto_lr_find: false
-  gradient_clip_val: 0.5
-  fast_dev_run: false
-  gpus: 1
-  precision: 16
-  max_epochs: *epochs
-  terminate_on_nan: true
-  weights_summary: null
-  limit_train_batches: 1.0
-  limit_val_batches: 1.0
-  limit_test_batches: 1.0
-  resume_from_checkpoint: null
-  accumulate_grad_batches: 1
-  overfit_batches: 0
diff --git a/training/conf/network/conformer.yaml b/training/conf/network/conformer.yaml
deleted file mode 100644
index 1d72dd5..0000000
--- a/training/conf/network/conformer.yaml
+++ /dev/null
@@ -1,36 +0,0 @@
-_target_: text_recognizer.networks.conformer.Conformer
-depth: 8
-num_classes: 57
-dim: &dim 144
-dim_gru: 144
-block:
-  _target_: text_recognizer.networks.conformer.ConformerBlock
-  dim: *dim
-  attn:
-    _target_: text_recognizer.networks.conformer.Attention
-    dim: *dim
-    heads: 8
-    dim_head: 64
-    mult: 4
-  ff:
-    _target_: text_recognizer.networks.conformer.Feedforward
-    dim: *dim
-    expansion_factor: 4
-    dropout: 0.1
-  conv:
-    _target_: text_recognizer.networks.conformer.ConformerConv
-    dim: *dim
-    expansion_factor: 2
-    kernel_size: 31
-    dropout: 0.1
-subsampler:
-  _target_: text_recognizer.networks.conformer.Subsampler
-  pixel_pos_embedding:
-    _target_: text_recognizer.networks.transformer.AxialPositionalEmbedding
-    dim: 64
-    shape: [6, 127]
-  channels: 64
-  height: 6
-  dim: *dim
-  depth: 3
-  dropout: 0.1
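For reference, the decoding step that disappears with `LitConformer.decode` above is plain greedy CTC collapse: take the per-step argmax, merge consecutive duplicates, and drop the blank symbol. Below is a minimal, self-contained sketch of that idea; the function name and the blank/pad indices are illustrative, not the repository's actual mapping.

```python
"""Greedy CTC decoding sketch, mirroring the removed LitConformer.decode."""
import itertools

import torch
from torch import Tensor


def greedy_ctc_decode(
    logprobs: Tensor, max_length: int, blank_index: int, pad_index: int
) -> Tensor:
    """Collapse (B, T, C) log probabilities into (B, max_length) class indices."""
    B = logprobs.shape[0]
    argmax = logprobs.argmax(2)
    decoded = torch.full((B, max_length), pad_index, dtype=torch.long)
    for i in range(B):
        # groupby merges runs of repeated symbols; blanks are then removed.
        seq = [
            b for b, _ in itertools.groupby(argmax[i].tolist()) if b != blank_index
        ][:max_length]
        decoded[i, : len(seq)] = torch.tensor(seq, dtype=torch.long)
    return decoded


if __name__ == "__main__":
    # Toy example: batch of 1, 5 time steps, 4 classes, class 0 as blank.
    logits = torch.randn(1, 5, 4)
    logprobs = torch.log_softmax(logits, dim=2)
    print(greedy_ctc_decode(logprobs, max_length=3, blank_index=0, pad_index=3))
```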