diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-12 23:16:20 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-12 23:16:20 +0200 |
commit | 8bb76745e43c6b4967c8e91ebaf4c4295d0b8d0b (patch) | |
tree | 5ff05d9fea92f7e5bd313d8cdc9559ccbc89a97a /text_recognizer | |
parent | 8fe4b36bf22281c84c4afee811b3435f3b50686d (diff) |
Remove conformer
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/models/conformer.py | 124 | ||||
-rw-r--r-- | text_recognizer/networks/conformer/__init__.py | 7 | ||||
-rw-r--r-- | text_recognizer/networks/conformer/attention.py | 49 | ||||
-rw-r--r-- | text_recognizer/networks/conformer/block.py | 34 | ||||
-rw-r--r-- | text_recognizer/networks/conformer/conformer.py | 35 | ||||
-rw-r--r-- | text_recognizer/networks/conformer/conv.py | 40 | ||||
-rw-r--r-- | text_recognizer/networks/conformer/ff.py | 19 | ||||
-rw-r--r-- | text_recognizer/networks/conformer/glu.py | 12 | ||||
-rw-r--r-- | text_recognizer/networks/conformer/scale.py | 13 | ||||
-rw-r--r-- | text_recognizer/networks/conformer/subsampler.py | 50 |
10 files changed, 0 insertions, 383 deletions
diff --git a/text_recognizer/models/conformer.py b/text_recognizer/models/conformer.py deleted file mode 100644 index 41a9e4d..0000000 --- a/text_recognizer/models/conformer.py +++ /dev/null @@ -1,124 +0,0 @@ -"""Lightning Conformer model.""" -import itertools -from typing import Optional, Tuple, Type - -from omegaconf import DictConfig -import torch -from torch import nn, Tensor - -from text_recognizer.data.mappings import EmnistMapping -from text_recognizer.models.base import LitBase -from text_recognizer.models.metrics import CharacterErrorRate -from text_recognizer.models.util import first_element - - -class LitConformer(LitBase): - """A PyTorch Lightning model for transformer networks.""" - - def __init__( - self, - network: Type[nn.Module], - loss_fn: Type[nn.Module], - optimizer_configs: DictConfig, - lr_scheduler_configs: Optional[DictConfig], - mapping: EmnistMapping, - max_output_len: int = 451, - start_token: str = "<s>", - end_token: str = "<e>", - pad_token: str = "<p>", - blank_token: str = "<b>", - ) -> None: - super().__init__( - network, loss_fn, optimizer_configs, lr_scheduler_configs, mapping - ) - self.max_output_len = max_output_len - self.start_token = start_token - self.end_token = end_token - self.pad_token = pad_token - self.blank_token = blank_token - self.start_index = int(self.mapping.get_index(self.start_token)) - self.end_index = int(self.mapping.get_index(self.end_token)) - self.pad_index = int(self.mapping.get_index(self.pad_token)) - self.blank_index = int(self.mapping.get_index(self.blank_token)) - self.ignore_indices = set( - [self.start_index, self.end_index, self.pad_index, self.blank_index] - ) - self.val_cer = CharacterErrorRate(self.ignore_indices) - self.test_cer = CharacterErrorRate(self.ignore_indices) - - @torch.no_grad() - def predict(self, x: Tensor) -> str: - """Predicts a sequence of characters.""" - logits = self(x) - logprobs = torch.log_softmax(logits, dim=1) - return self.decode(logprobs, self.max_output_len) - - def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: - """Training step.""" - data, targets = batch - logits = self(data) - logprobs = torch.log_softmax(logits, dim=1) - B, S, _ = logprobs.shape - input_length = torch.ones(B).type_as(logprobs).int() * S - target_length = first_element(targets, self.pad_index).type_as(targets) - loss = self.loss_fn( - logprobs.permute(1, 0, 2), targets, input_length, target_length - ) - self.log("train/loss", loss) - return loss - - def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: - """Validation step.""" - data, targets = batch - logits = self(data) - logprobs = torch.log_softmax(logits, dim=1) - B, S, _ = logprobs.shape - input_length = torch.ones(B).type_as(logprobs).int() * S - target_length = first_element(targets, self.pad_index).type_as(targets) - loss = self.loss_fn( - logprobs.permute(1, 0, 2), targets, input_length, target_length - ) - self.log("val/loss", loss) - preds = self.decode(logprobs, targets.shape[1]) - self.val_acc(preds, targets) - self.log("val/acc", self.val_acc, on_step=False, on_epoch=True) - self.val_cer(preds, targets) - self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True) - - def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: - """Test step.""" - data, targets = batch - logits = self(data) - logprobs = torch.log_softmax(logits, dim=1) - preds = self.decode(logprobs, targets.shape[1]) - self.val_acc(preds, targets) - self.log("val/acc", self.val_acc, on_step=False, on_epoch=True) - self.val_cer(preds, targets) - self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True) - - def decode(self, logprobs: Tensor, max_length: int) -> Tensor: - """Greedly decodes a log prob sequence. - - Args: - logprobs (Tensor): Log probabilities. - max_length (int): Max length of a sequence. - - Shapes: - - x: :math: `(B, T, C)` - - output: :math: `(B, T)` - - Returns: - Tensor: A predicted sequence of characters. - """ - B = logprobs.shape[0] - argmax = logprobs.argmax(2) - decoded = torch.ones((B, max_length)).type_as(logprobs).int() * self.pad_index - for i in range(B): - seq = [ - b - for b, _ in itertools.groupby(argmax[i].tolist()) - if b != self.blank_index - ][:max_length] - for j, c in enumerate(seq): - decoded[i, j] = c - return decoded diff --git a/text_recognizer/networks/conformer/__init__.py b/text_recognizer/networks/conformer/__init__.py deleted file mode 100644 index 8951481..0000000 --- a/text_recognizer/networks/conformer/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from text_recognizer.networks.conformer.block import ConformerBlock -from text_recognizer.networks.conformer.ff import Feedforward -from text_recognizer.networks.conformer.glu import GLU -from text_recognizer.networks.conformer.conformer import Conformer -from text_recognizer.networks.conformer.conv import ConformerConv -from text_recognizer.networks.conformer.subsampler import Subsampler -from text_recognizer.networks.conformer.attention import Attention diff --git a/text_recognizer/networks/conformer/attention.py b/text_recognizer/networks/conformer/attention.py deleted file mode 100644 index e56e572..0000000 --- a/text_recognizer/networks/conformer/attention.py +++ /dev/null @@ -1,49 +0,0 @@ -"""Efficient self attention.""" -from einops import rearrange -import torch -import torch.nn.functional as F -from torch import einsum, nn, Tensor - - -class LayerNorm(nn.Module): - def __init__(self, dim: int) -> None: - super().__init__() - self.gamma = nn.Parameter(torch.ones(dim)) - self.register_buffer("beta", torch.zeros(dim)) - - def forward(self, x: Tensor) -> Tensor: - return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) - - -class SwiGLU(nn.Module): - def forward(self, x: Tensor) -> Tensor: - x, gate = x.chunk(2, dim=-1) - return F.silu(gate) * x - - -class Attention(nn.Module): - def __init__( - self, dim: int, dim_head: int = 64, heads: int = 8, mult: int = 4 - ) -> None: - super().__init__() - self.norm = LayerNorm(dim) - attn_inner_dim = heads * dim_head - ff_inner_dim = mult * dim - self.heads = heads - self.scale = dim_head ** -0.5 - self.fused_dims = (attn_inner_dim, dim_head, dim_head, (2 * ff_inner_dim)) - self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False) - self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False) - self.ff_out = nn.Sequential(SwiGLU(), nn.Linear(ff_inner_dim, dim, bias=False)) - - def forward(self, x: Tensor) -> Tensor: - h = self.heads - x = self.norm(x) - q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1) - q = rearrange(q, "b n (h d) -> b h n d", h=h) - q = q * self.scale - sim = einsum("b h i d, b j d -> b h i j", q, k) - attn = sim.softmax(dim=-1) - out = einsum("b h i j, b j d -> b h i d", attn, v) - out = rearrange(out, "b h n d -> b n (h d)") - return self.attn_out(out) + self.ff_out(ff) diff --git a/text_recognizer/networks/conformer/block.py b/text_recognizer/networks/conformer/block.py deleted file mode 100644 index c53f339..0000000 --- a/text_recognizer/networks/conformer/block.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Conformer block.""" -from copy import deepcopy -from typing import Optional - -from torch import nn, Tensor -from text_recognizer.networks.conformer.conv import ConformerConv - -from text_recognizer.networks.conformer.ff import Feedforward -from text_recognizer.networks.conformer.scale import Scale -from text_recognizer.networks.transformer.attention import Attention -from text_recognizer.networks.transformer.norm import PreNorm - - -class ConformerBlock(nn.Module): - def __init__( - self, - dim: int, - ff: Feedforward, - attn: Attention, - conv: ConformerConv, - ) -> None: - super().__init__() - self.attn = PreNorm(dim, attn) - self.ff_1 = Scale(0.5, ff) - self.ff_2 = deepcopy(self.ff_1) - self.conv = conv - self.post_norm = nn.LayerNorm(dim) - - def forward(self, x: Tensor) -> Tensor: - x = self.ff_1(x) + x - x = self.attn(x) + x - x = self.conv(x) + x - x = self.ff_2(x) + x - return self.post_norm(x) diff --git a/text_recognizer/networks/conformer/conformer.py b/text_recognizer/networks/conformer/conformer.py deleted file mode 100644 index 09aad55..0000000 --- a/text_recognizer/networks/conformer/conformer.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Conformer module.""" -from copy import deepcopy -from typing import Type - -from torch import nn, Tensor - -from text_recognizer.networks.conformer.block import ConformerBlock - - -class Conformer(nn.Module): - def __init__( - self, - dim: int, - dim_gru: int, - num_classes: int, - subsampler: Type[nn.Module], - block: ConformerBlock, - depth: int, - ) -> None: - super().__init__() - self.subsampler = subsampler - self.blocks = nn.ModuleList([deepcopy(block) for _ in range(depth)]) - self.gru = nn.GRU( - dim, dim_gru, 1, bidirectional=True, batch_first=True, bias=False - ) - self.fc = nn.Linear(dim_gru, num_classes) - - def forward(self, x: Tensor) -> Tensor: - x = self.subsampler(x) - B, T, C = x.shape - for fn in self.blocks: - x = fn(x) - x, _ = self.gru(x) - x = x.view(B, T, 2, -1).sum(2) - return self.fc(x) diff --git a/text_recognizer/networks/conformer/conv.py b/text_recognizer/networks/conformer/conv.py deleted file mode 100644 index ac13f5d..0000000 --- a/text_recognizer/networks/conformer/conv.py +++ /dev/null @@ -1,40 +0,0 @@ -"""Conformer convolutional block.""" -from einops import rearrange -from einops.layers.torch import Rearrange -from torch import nn, Tensor - - -from text_recognizer.networks.conformer.glu import GLU - - -class ConformerConv(nn.Module): - def __init__( - self, - dim: int, - expansion_factor: int = 2, - kernel_size: int = 31, - dropout: int = 0.0, - ) -> None: - super().__init__() - inner_dim = expansion_factor * dim - self.layers = nn.Sequential( - nn.LayerNorm(dim), - Rearrange("b n c -> b c n"), - nn.Conv1d(dim, 2 * inner_dim, 1), - GLU(dim=1), - nn.Conv1d( - in_channels=inner_dim, - out_channels=inner_dim, - kernel_size=kernel_size, - groups=inner_dim, - padding="same", - ), - nn.BatchNorm1d(inner_dim), - nn.Mish(inplace=True), - nn.Conv1d(inner_dim, dim, 1), - Rearrange("b c n -> b n c"), - nn.Dropout(dropout), - ) - - def forward(self, x: Tensor) -> Tensor: - return self.layers(x) diff --git a/text_recognizer/networks/conformer/ff.py b/text_recognizer/networks/conformer/ff.py deleted file mode 100644 index 2ef4245..0000000 --- a/text_recognizer/networks/conformer/ff.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Conformer feedforward block.""" -from torch import nn, Tensor - - -class Feedforward(nn.Module): - def __init__( - self, dim: int, expansion_factor: int = 4, dropout: float = 0.0 - ) -> None: - super().__init__() - self.layers = nn.Sequential( - nn.Linear(dim, expansion_factor * dim), - nn.Mish(inplace=True), - nn.Dropout(dropout), - nn.Linear(expansion_factor * dim, dim), - nn.Dropout(dropout), - ) - - def forward(self, x: Tensor) -> Tensor: - return self.layers(x) diff --git a/text_recognizer/networks/conformer/glu.py b/text_recognizer/networks/conformer/glu.py deleted file mode 100644 index 1a7c201..0000000 --- a/text_recognizer/networks/conformer/glu.py +++ /dev/null @@ -1,12 +0,0 @@ -"""GLU layer.""" -from torch import nn, Tensor - - -class GLU(nn.Module): - def __init__(self, dim: int) -> None: - super().__init__() - self.dim = dim - - def forward(self, x: Tensor) -> Tensor: - out, gate = x.chunk(2, dim=self.dim) - return out * gate.sigmoid() diff --git a/text_recognizer/networks/conformer/scale.py b/text_recognizer/networks/conformer/scale.py deleted file mode 100644 index d012b81..0000000 --- a/text_recognizer/networks/conformer/scale.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Scale layer.""" -from typing import Dict -from torch import nn, Tensor - - -class Scale(nn.Module): - def __init__(self, scale: float, fn: nn.Module) -> None: - super().__init__() - self.scale = scale - self.fn = fn - - def forward(self, x: Tensor, **kwargs) -> Tensor: - return self.fn(x, **kwargs) * self.scale diff --git a/text_recognizer/networks/conformer/subsampler.py b/text_recognizer/networks/conformer/subsampler.py deleted file mode 100644 index 133b53a..0000000 --- a/text_recognizer/networks/conformer/subsampler.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Simple convolutional network.""" -from typing import Tuple - -from einops import rearrange -from torch import nn, Tensor - -from text_recognizer.networks.transformer import AxialPositionalEmbedding - - -class Subsampler(nn.Module): - def __init__( - self, - channels: int, - dim: int, - depth: int, - height: int, - pixel_pos_embedding: AxialPositionalEmbedding, - dropout: float = 0.1, - ) -> None: - super().__init__() - self.pixel_pos_embedding = pixel_pos_embedding - self.subsampler, self.projector = self._build( - channels, height, dim, depth, dropout - ) - - def _build( - self, channels: int, height: int, dim: int, depth: int, dropout: float - ) -> Tuple[nn.Sequential, nn.Sequential]: - subsampler = [] - for i in range(depth): - subsampler.append( - nn.Conv2d( - in_channels=1 if i == 0 else channels, - out_channels=channels, - kernel_size=3, - stride=2, - ) - ) - subsampler.append(nn.Mish(inplace=True)) - projector = nn.Sequential( - nn.Linear(channels * height, dim), nn.Dropout(dropout) - ) - return nn.Sequential(*subsampler), projector - - def forward(self, x: Tensor) -> Tensor: - x = self.subsampler(x) - x = self.pixel_pos_embedding(x) - x = rearrange(x, "b c h w -> b w (c h)") - x = self.projector(x) - return x |