summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--text_recognizer/models/conformer.py124
-rw-r--r--text_recognizer/networks/conformer/__init__.py7
-rw-r--r--text_recognizer/networks/conformer/attention.py49
-rw-r--r--text_recognizer/networks/conformer/block.py34
-rw-r--r--text_recognizer/networks/conformer/conformer.py35
-rw-r--r--text_recognizer/networks/conformer/conv.py40
-rw-r--r--text_recognizer/networks/conformer/ff.py19
-rw-r--r--text_recognizer/networks/conformer/glu.py12
-rw-r--r--text_recognizer/networks/conformer/scale.py13
-rw-r--r--text_recognizer/networks/conformer/subsampler.py50
-rw-r--r--training/conf/experiment/conformer_lines.yaml88
-rw-r--r--training/conf/network/conformer.yaml36
12 files changed, 0 insertions, 507 deletions
diff --git a/text_recognizer/models/conformer.py b/text_recognizer/models/conformer.py
deleted file mode 100644
index 41a9e4d..0000000
--- a/text_recognizer/models/conformer.py
+++ /dev/null
@@ -1,124 +0,0 @@
-"""Lightning Conformer model."""
-import itertools
-from typing import Optional, Tuple, Type
-
-from omegaconf import DictConfig
-import torch
-from torch import nn, Tensor
-
-from text_recognizer.data.mappings import EmnistMapping
-from text_recognizer.models.base import LitBase
-from text_recognizer.models.metrics import CharacterErrorRate
-from text_recognizer.models.util import first_element
-
-
-class LitConformer(LitBase):
- """A PyTorch Lightning model for transformer networks."""
-
- def __init__(
- self,
- network: Type[nn.Module],
- loss_fn: Type[nn.Module],
- optimizer_configs: DictConfig,
- lr_scheduler_configs: Optional[DictConfig],
- mapping: EmnistMapping,
- max_output_len: int = 451,
- start_token: str = "<s>",
- end_token: str = "<e>",
- pad_token: str = "<p>",
- blank_token: str = "<b>",
- ) -> None:
- super().__init__(
- network, loss_fn, optimizer_configs, lr_scheduler_configs, mapping
- )
- self.max_output_len = max_output_len
- self.start_token = start_token
- self.end_token = end_token
- self.pad_token = pad_token
- self.blank_token = blank_token
- self.start_index = int(self.mapping.get_index(self.start_token))
- self.end_index = int(self.mapping.get_index(self.end_token))
- self.pad_index = int(self.mapping.get_index(self.pad_token))
- self.blank_index = int(self.mapping.get_index(self.blank_token))
- self.ignore_indices = set(
- [self.start_index, self.end_index, self.pad_index, self.blank_index]
- )
- self.val_cer = CharacterErrorRate(self.ignore_indices)
- self.test_cer = CharacterErrorRate(self.ignore_indices)
-
- @torch.no_grad()
- def predict(self, x: Tensor) -> str:
- """Predicts a sequence of characters."""
- logits = self(x)
- logprobs = torch.log_softmax(logits, dim=1)
- return self.decode(logprobs, self.max_output_len)
-
- def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
- """Training step."""
- data, targets = batch
- logits = self(data)
- logprobs = torch.log_softmax(logits, dim=1)
- B, S, _ = logprobs.shape
- input_length = torch.ones(B).type_as(logprobs).int() * S
- target_length = first_element(targets, self.pad_index).type_as(targets)
- loss = self.loss_fn(
- logprobs.permute(1, 0, 2), targets, input_length, target_length
- )
- self.log("train/loss", loss)
- return loss
-
- def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
- """Validation step."""
- data, targets = batch
- logits = self(data)
- logprobs = torch.log_softmax(logits, dim=1)
- B, S, _ = logprobs.shape
- input_length = torch.ones(B).type_as(logprobs).int() * S
- target_length = first_element(targets, self.pad_index).type_as(targets)
- loss = self.loss_fn(
- logprobs.permute(1, 0, 2), targets, input_length, target_length
- )
- self.log("val/loss", loss)
- preds = self.decode(logprobs, targets.shape[1])
- self.val_acc(preds, targets)
- self.log("val/acc", self.val_acc, on_step=False, on_epoch=True)
- self.val_cer(preds, targets)
- self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
-
- def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
- """Test step."""
- data, targets = batch
- logits = self(data)
- logprobs = torch.log_softmax(logits, dim=1)
- preds = self.decode(logprobs, targets.shape[1])
- self.val_acc(preds, targets)
- self.log("val/acc", self.val_acc, on_step=False, on_epoch=True)
- self.val_cer(preds, targets)
- self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
-
- def decode(self, logprobs: Tensor, max_length: int) -> Tensor:
- """Greedly decodes a log prob sequence.
-
- Args:
- logprobs (Tensor): Log probabilities.
- max_length (int): Max length of a sequence.
-
- Shapes:
- - x: :math: `(B, T, C)`
- - output: :math: `(B, T)`
-
- Returns:
- Tensor: A predicted sequence of characters.
- """
- B = logprobs.shape[0]
- argmax = logprobs.argmax(2)
- decoded = torch.ones((B, max_length)).type_as(logprobs).int() * self.pad_index
- for i in range(B):
- seq = [
- b
- for b, _ in itertools.groupby(argmax[i].tolist())
- if b != self.blank_index
- ][:max_length]
- for j, c in enumerate(seq):
- decoded[i, j] = c
- return decoded
diff --git a/text_recognizer/networks/conformer/__init__.py b/text_recognizer/networks/conformer/__init__.py
deleted file mode 100644
index 8951481..0000000
--- a/text_recognizer/networks/conformer/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from text_recognizer.networks.conformer.block import ConformerBlock
-from text_recognizer.networks.conformer.ff import Feedforward
-from text_recognizer.networks.conformer.glu import GLU
-from text_recognizer.networks.conformer.conformer import Conformer
-from text_recognizer.networks.conformer.conv import ConformerConv
-from text_recognizer.networks.conformer.subsampler import Subsampler
-from text_recognizer.networks.conformer.attention import Attention
diff --git a/text_recognizer/networks/conformer/attention.py b/text_recognizer/networks/conformer/attention.py
deleted file mode 100644
index e56e572..0000000
--- a/text_recognizer/networks/conformer/attention.py
+++ /dev/null
@@ -1,49 +0,0 @@
-"""Efficient self attention."""
-from einops import rearrange
-import torch
-import torch.nn.functional as F
-from torch import einsum, nn, Tensor
-
-
-class LayerNorm(nn.Module):
- def __init__(self, dim: int) -> None:
- super().__init__()
- self.gamma = nn.Parameter(torch.ones(dim))
- self.register_buffer("beta", torch.zeros(dim))
-
- def forward(self, x: Tensor) -> Tensor:
- return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
-
-
-class SwiGLU(nn.Module):
- def forward(self, x: Tensor) -> Tensor:
- x, gate = x.chunk(2, dim=-1)
- return F.silu(gate) * x
-
-
-class Attention(nn.Module):
- def __init__(
- self, dim: int, dim_head: int = 64, heads: int = 8, mult: int = 4
- ) -> None:
- super().__init__()
- self.norm = LayerNorm(dim)
- attn_inner_dim = heads * dim_head
- ff_inner_dim = mult * dim
- self.heads = heads
- self.scale = dim_head ** -0.5
- self.fused_dims = (attn_inner_dim, dim_head, dim_head, (2 * ff_inner_dim))
- self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
- self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)
- self.ff_out = nn.Sequential(SwiGLU(), nn.Linear(ff_inner_dim, dim, bias=False))
-
- def forward(self, x: Tensor) -> Tensor:
- h = self.heads
- x = self.norm(x)
- q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)
- q = rearrange(q, "b n (h d) -> b h n d", h=h)
- q = q * self.scale
- sim = einsum("b h i d, b j d -> b h i j", q, k)
- attn = sim.softmax(dim=-1)
- out = einsum("b h i j, b j d -> b h i d", attn, v)
- out = rearrange(out, "b h n d -> b n (h d)")
- return self.attn_out(out) + self.ff_out(ff)
diff --git a/text_recognizer/networks/conformer/block.py b/text_recognizer/networks/conformer/block.py
deleted file mode 100644
index c53f339..0000000
--- a/text_recognizer/networks/conformer/block.py
+++ /dev/null
@@ -1,34 +0,0 @@
-"""Conformer block."""
-from copy import deepcopy
-from typing import Optional
-
-from torch import nn, Tensor
-from text_recognizer.networks.conformer.conv import ConformerConv
-
-from text_recognizer.networks.conformer.ff import Feedforward
-from text_recognizer.networks.conformer.scale import Scale
-from text_recognizer.networks.transformer.attention import Attention
-from text_recognizer.networks.transformer.norm import PreNorm
-
-
-class ConformerBlock(nn.Module):
- def __init__(
- self,
- dim: int,
- ff: Feedforward,
- attn: Attention,
- conv: ConformerConv,
- ) -> None:
- super().__init__()
- self.attn = PreNorm(dim, attn)
- self.ff_1 = Scale(0.5, ff)
- self.ff_2 = deepcopy(self.ff_1)
- self.conv = conv
- self.post_norm = nn.LayerNorm(dim)
-
- def forward(self, x: Tensor) -> Tensor:
- x = self.ff_1(x) + x
- x = self.attn(x) + x
- x = self.conv(x) + x
- x = self.ff_2(x) + x
- return self.post_norm(x)
diff --git a/text_recognizer/networks/conformer/conformer.py b/text_recognizer/networks/conformer/conformer.py
deleted file mode 100644
index 09aad55..0000000
--- a/text_recognizer/networks/conformer/conformer.py
+++ /dev/null
@@ -1,35 +0,0 @@
-"""Conformer module."""
-from copy import deepcopy
-from typing import Type
-
-from torch import nn, Tensor
-
-from text_recognizer.networks.conformer.block import ConformerBlock
-
-
-class Conformer(nn.Module):
- def __init__(
- self,
- dim: int,
- dim_gru: int,
- num_classes: int,
- subsampler: Type[nn.Module],
- block: ConformerBlock,
- depth: int,
- ) -> None:
- super().__init__()
- self.subsampler = subsampler
- self.blocks = nn.ModuleList([deepcopy(block) for _ in range(depth)])
- self.gru = nn.GRU(
- dim, dim_gru, 1, bidirectional=True, batch_first=True, bias=False
- )
- self.fc = nn.Linear(dim_gru, num_classes)
-
- def forward(self, x: Tensor) -> Tensor:
- x = self.subsampler(x)
- B, T, C = x.shape
- for fn in self.blocks:
- x = fn(x)
- x, _ = self.gru(x)
- x = x.view(B, T, 2, -1).sum(2)
- return self.fc(x)
diff --git a/text_recognizer/networks/conformer/conv.py b/text_recognizer/networks/conformer/conv.py
deleted file mode 100644
index ac13f5d..0000000
--- a/text_recognizer/networks/conformer/conv.py
+++ /dev/null
@@ -1,40 +0,0 @@
-"""Conformer convolutional block."""
-from einops import rearrange
-from einops.layers.torch import Rearrange
-from torch import nn, Tensor
-
-
-from text_recognizer.networks.conformer.glu import GLU
-
-
-class ConformerConv(nn.Module):
- def __init__(
- self,
- dim: int,
- expansion_factor: int = 2,
- kernel_size: int = 31,
- dropout: int = 0.0,
- ) -> None:
- super().__init__()
- inner_dim = expansion_factor * dim
- self.layers = nn.Sequential(
- nn.LayerNorm(dim),
- Rearrange("b n c -> b c n"),
- nn.Conv1d(dim, 2 * inner_dim, 1),
- GLU(dim=1),
- nn.Conv1d(
- in_channels=inner_dim,
- out_channels=inner_dim,
- kernel_size=kernel_size,
- groups=inner_dim,
- padding="same",
- ),
- nn.BatchNorm1d(inner_dim),
- nn.Mish(inplace=True),
- nn.Conv1d(inner_dim, dim, 1),
- Rearrange("b c n -> b n c"),
- nn.Dropout(dropout),
- )
-
- def forward(self, x: Tensor) -> Tensor:
- return self.layers(x)
diff --git a/text_recognizer/networks/conformer/ff.py b/text_recognizer/networks/conformer/ff.py
deleted file mode 100644
index 2ef4245..0000000
--- a/text_recognizer/networks/conformer/ff.py
+++ /dev/null
@@ -1,19 +0,0 @@
-"""Conformer feedforward block."""
-from torch import nn, Tensor
-
-
-class Feedforward(nn.Module):
- def __init__(
- self, dim: int, expansion_factor: int = 4, dropout: float = 0.0
- ) -> None:
- super().__init__()
- self.layers = nn.Sequential(
- nn.Linear(dim, expansion_factor * dim),
- nn.Mish(inplace=True),
- nn.Dropout(dropout),
- nn.Linear(expansion_factor * dim, dim),
- nn.Dropout(dropout),
- )
-
- def forward(self, x: Tensor) -> Tensor:
- return self.layers(x)
diff --git a/text_recognizer/networks/conformer/glu.py b/text_recognizer/networks/conformer/glu.py
deleted file mode 100644
index 1a7c201..0000000
--- a/text_recognizer/networks/conformer/glu.py
+++ /dev/null
@@ -1,12 +0,0 @@
-"""GLU layer."""
-from torch import nn, Tensor
-
-
-class GLU(nn.Module):
- def __init__(self, dim: int) -> None:
- super().__init__()
- self.dim = dim
-
- def forward(self, x: Tensor) -> Tensor:
- out, gate = x.chunk(2, dim=self.dim)
- return out * gate.sigmoid()
diff --git a/text_recognizer/networks/conformer/scale.py b/text_recognizer/networks/conformer/scale.py
deleted file mode 100644
index d012b81..0000000
--- a/text_recognizer/networks/conformer/scale.py
+++ /dev/null
@@ -1,13 +0,0 @@
-"""Scale layer."""
-from typing import Dict
-from torch import nn, Tensor
-
-
-class Scale(nn.Module):
- def __init__(self, scale: float, fn: nn.Module) -> None:
- super().__init__()
- self.scale = scale
- self.fn = fn
-
- def forward(self, x: Tensor, **kwargs) -> Tensor:
- return self.fn(x, **kwargs) * self.scale
diff --git a/text_recognizer/networks/conformer/subsampler.py b/text_recognizer/networks/conformer/subsampler.py
deleted file mode 100644
index 133b53a..0000000
--- a/text_recognizer/networks/conformer/subsampler.py
+++ /dev/null
@@ -1,50 +0,0 @@
-"""Simple convolutional network."""
-from typing import Tuple
-
-from einops import rearrange
-from torch import nn, Tensor
-
-from text_recognizer.networks.transformer import AxialPositionalEmbedding
-
-
-class Subsampler(nn.Module):
- def __init__(
- self,
- channels: int,
- dim: int,
- depth: int,
- height: int,
- pixel_pos_embedding: AxialPositionalEmbedding,
- dropout: float = 0.1,
- ) -> None:
- super().__init__()
- self.pixel_pos_embedding = pixel_pos_embedding
- self.subsampler, self.projector = self._build(
- channels, height, dim, depth, dropout
- )
-
- def _build(
- self, channels: int, height: int, dim: int, depth: int, dropout: float
- ) -> Tuple[nn.Sequential, nn.Sequential]:
- subsampler = []
- for i in range(depth):
- subsampler.append(
- nn.Conv2d(
- in_channels=1 if i == 0 else channels,
- out_channels=channels,
- kernel_size=3,
- stride=2,
- )
- )
- subsampler.append(nn.Mish(inplace=True))
- projector = nn.Sequential(
- nn.Linear(channels * height, dim), nn.Dropout(dropout)
- )
- return nn.Sequential(*subsampler), projector
-
- def forward(self, x: Tensor) -> Tensor:
- x = self.subsampler(x)
- x = self.pixel_pos_embedding(x)
- x = rearrange(x, "b c h w -> b w (c h)")
- x = self.projector(x)
- return x
diff --git a/training/conf/experiment/conformer_lines.yaml b/training/conf/experiment/conformer_lines.yaml
deleted file mode 100644
index 06e761e..0000000
--- a/training/conf/experiment/conformer_lines.yaml
+++ /dev/null
@@ -1,88 +0,0 @@
-# @package _global_
-
-defaults:
- - override /mapping: null
- - override /criterion: ctc
- - override /callbacks: htr
- - override /datamodule: iam_lines
- - override /network: conformer
- - override /model: null
- - override /lr_schedulers: null
- - override /optimizers: null
-
-epochs: &epochs 999
-num_classes: &num_classes 57
-max_output_len: &max_output_len 89
-summary: [[1, 56, 1024]]
-
-mapping: &mapping
- mapping:
- _target_: text_recognizer.data.mappings.EmnistMapping
-
-callbacks:
- stochastic_weight_averaging:
- _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
- swa_epoch_start: 0.75
- swa_lrs: 1.0e-5
- annealing_epochs: 10
- annealing_strategy: cos
- device: null
-
-optimizers:
- radam:
- _target_: torch.optim.RAdam
- lr: 3.0e-4
- betas: [0.9, 0.999]
- weight_decay: 0
- eps: 1.0e-8
- parameters: network
-
-lr_schedulers:
- network:
- _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
- mode: min
- factor: 0.5
- patience: 10
- threshold: 1.0e-4
- threshold_mode: rel
- cooldown: 0
- min_lr: 1.0e-5
- eps: 1.0e-8
- verbose: false
- interval: epoch
- monitor: val/loss
-
-datamodule:
- batch_size: 8
- num_workers: 12
- train_fraction: 0.9
- pin_memory: true
- << : *mapping
-
-model:
- _target_: text_recognizer.models.conformer.LitConformer
- <<: *mapping
- max_output_len: *max_output_len
- start_token: <s>
- end_token: <e>
- pad_token: <p>
- blank_token: <b>
-
-trainer:
- _target_: pytorch_lightning.Trainer
- stochastic_weight_avg: true
- auto_scale_batch_size: binsearch
- auto_lr_find: false
- gradient_clip_val: 0.5
- fast_dev_run: false
- gpus: 1
- precision: 16
- max_epochs: *epochs
- terminate_on_nan: true
- weights_summary: null
- limit_train_batches: 1.0
- limit_val_batches: 1.0
- limit_test_batches: 1.0
- resume_from_checkpoint: null
- accumulate_grad_batches: 1
- overfit_batches: 0
diff --git a/training/conf/network/conformer.yaml b/training/conf/network/conformer.yaml
deleted file mode 100644
index 1d72dd5..0000000
--- a/training/conf/network/conformer.yaml
+++ /dev/null
@@ -1,36 +0,0 @@
-_target_: text_recognizer.networks.conformer.Conformer
-depth: 8
-num_classes: 57
-dim: &dim 144
-dim_gru: 144
-block:
- _target_: text_recognizer.networks.conformer.ConformerBlock
- dim: *dim
- attn:
- _target_: text_recognizer.networks.conformer.Attention
- dim: *dim
- heads: 8
- dim_head: 64
- mult: 4
- ff:
- _target_: text_recognizer.networks.conformer.Feedforward
- dim: *dim
- expansion_factor: 4
- dropout: 0.1
- conv:
- _target_: text_recognizer.networks.conformer.ConformerConv
- dim: *dim
- expansion_factor: 2
- kernel_size: 31
- dropout: 0.1
-subsampler:
- _target_: text_recognizer.networks.conformer.Subsampler
- pixel_pos_embedding:
- _target_: text_recognizer.networks.transformer.AxialPositionalEmbedding
- dim: 64
- shape: [6, 127]
- channels: 64
- height: 6
- dim: *dim
- depth: 3
- dropout: 0.1