diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-12 23:16:20 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-12 23:16:20 +0200 |
commit | 8bb76745e43c6b4967c8e91ebaf4c4295d0b8d0b (patch) | |
tree | 5ff05d9fea92f7e5bd313d8cdc9559ccbc89a97a /text_recognizer/networks/conformer | |
parent | 8fe4b36bf22281c84c4afee811b3435f3b50686d (diff) |
Remove conformer
Diffstat (limited to 'text_recognizer/networks/conformer')
-rw-r--r-- | text_recognizer/networks/conformer/__init__.py | 7 | ||||
-rw-r--r-- | text_recognizer/networks/conformer/attention.py | 49 | ||||
-rw-r--r-- | text_recognizer/networks/conformer/block.py | 34 | ||||
-rw-r--r-- | text_recognizer/networks/conformer/conformer.py | 35 | ||||
-rw-r--r-- | text_recognizer/networks/conformer/conv.py | 40 | ||||
-rw-r--r-- | text_recognizer/networks/conformer/ff.py | 19 | ||||
-rw-r--r-- | text_recognizer/networks/conformer/glu.py | 12 | ||||
-rw-r--r-- | text_recognizer/networks/conformer/scale.py | 13 | ||||
-rw-r--r-- | text_recognizer/networks/conformer/subsampler.py | 50 |
9 files changed, 0 insertions, 259 deletions
diff --git a/text_recognizer/networks/conformer/__init__.py b/text_recognizer/networks/conformer/__init__.py deleted file mode 100644 index 8951481..0000000 --- a/text_recognizer/networks/conformer/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from text_recognizer.networks.conformer.block import ConformerBlock -from text_recognizer.networks.conformer.ff import Feedforward -from text_recognizer.networks.conformer.glu import GLU -from text_recognizer.networks.conformer.conformer import Conformer -from text_recognizer.networks.conformer.conv import ConformerConv -from text_recognizer.networks.conformer.subsampler import Subsampler -from text_recognizer.networks.conformer.attention import Attention diff --git a/text_recognizer/networks/conformer/attention.py b/text_recognizer/networks/conformer/attention.py deleted file mode 100644 index e56e572..0000000 --- a/text_recognizer/networks/conformer/attention.py +++ /dev/null @@ -1,49 +0,0 @@ -"""Efficient self attention.""" -from einops import rearrange -import torch -import torch.nn.functional as F -from torch import einsum, nn, Tensor - - -class LayerNorm(nn.Module): - def __init__(self, dim: int) -> None: - super().__init__() - self.gamma = nn.Parameter(torch.ones(dim)) - self.register_buffer("beta", torch.zeros(dim)) - - def forward(self, x: Tensor) -> Tensor: - return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) - - -class SwiGLU(nn.Module): - def forward(self, x: Tensor) -> Tensor: - x, gate = x.chunk(2, dim=-1) - return F.silu(gate) * x - - -class Attention(nn.Module): - def __init__( - self, dim: int, dim_head: int = 64, heads: int = 8, mult: int = 4 - ) -> None: - super().__init__() - self.norm = LayerNorm(dim) - attn_inner_dim = heads * dim_head - ff_inner_dim = mult * dim - self.heads = heads - self.scale = dim_head ** -0.5 - self.fused_dims = (attn_inner_dim, dim_head, dim_head, (2 * ff_inner_dim)) - self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False) - self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False) - self.ff_out = nn.Sequential(SwiGLU(), nn.Linear(ff_inner_dim, dim, bias=False)) - - def forward(self, x: Tensor) -> Tensor: - h = self.heads - x = self.norm(x) - q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1) - q = rearrange(q, "b n (h d) -> b h n d", h=h) - q = q * self.scale - sim = einsum("b h i d, b j d -> b h i j", q, k) - attn = sim.softmax(dim=-1) - out = einsum("b h i j, b j d -> b h i d", attn, v) - out = rearrange(out, "b h n d -> b n (h d)") - return self.attn_out(out) + self.ff_out(ff) diff --git a/text_recognizer/networks/conformer/block.py b/text_recognizer/networks/conformer/block.py deleted file mode 100644 index c53f339..0000000 --- a/text_recognizer/networks/conformer/block.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Conformer block.""" -from copy import deepcopy -from typing import Optional - -from torch import nn, Tensor -from text_recognizer.networks.conformer.conv import ConformerConv - -from text_recognizer.networks.conformer.ff import Feedforward -from text_recognizer.networks.conformer.scale import Scale -from text_recognizer.networks.transformer.attention import Attention -from text_recognizer.networks.transformer.norm import PreNorm - - -class ConformerBlock(nn.Module): - def __init__( - self, - dim: int, - ff: Feedforward, - attn: Attention, - conv: ConformerConv, - ) -> None: - super().__init__() - self.attn = PreNorm(dim, attn) - self.ff_1 = Scale(0.5, ff) - self.ff_2 = deepcopy(self.ff_1) - self.conv = conv - self.post_norm = nn.LayerNorm(dim) - - def forward(self, x: Tensor) -> Tensor: - x = self.ff_1(x) + x - x = self.attn(x) + x - x = self.conv(x) + x - x = self.ff_2(x) + x - return self.post_norm(x) diff --git a/text_recognizer/networks/conformer/conformer.py b/text_recognizer/networks/conformer/conformer.py deleted file mode 100644 index 09aad55..0000000 --- a/text_recognizer/networks/conformer/conformer.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Conformer module.""" -from copy import deepcopy -from typing import Type - -from torch import nn, Tensor - -from text_recognizer.networks.conformer.block import ConformerBlock - - -class Conformer(nn.Module): - def __init__( - self, - dim: int, - dim_gru: int, - num_classes: int, - subsampler: Type[nn.Module], - block: ConformerBlock, - depth: int, - ) -> None: - super().__init__() - self.subsampler = subsampler - self.blocks = nn.ModuleList([deepcopy(block) for _ in range(depth)]) - self.gru = nn.GRU( - dim, dim_gru, 1, bidirectional=True, batch_first=True, bias=False - ) - self.fc = nn.Linear(dim_gru, num_classes) - - def forward(self, x: Tensor) -> Tensor: - x = self.subsampler(x) - B, T, C = x.shape - for fn in self.blocks: - x = fn(x) - x, _ = self.gru(x) - x = x.view(B, T, 2, -1).sum(2) - return self.fc(x) diff --git a/text_recognizer/networks/conformer/conv.py b/text_recognizer/networks/conformer/conv.py deleted file mode 100644 index ac13f5d..0000000 --- a/text_recognizer/networks/conformer/conv.py +++ /dev/null @@ -1,40 +0,0 @@ -"""Conformer convolutional block.""" -from einops import rearrange -from einops.layers.torch import Rearrange -from torch import nn, Tensor - - -from text_recognizer.networks.conformer.glu import GLU - - -class ConformerConv(nn.Module): - def __init__( - self, - dim: int, - expansion_factor: int = 2, - kernel_size: int = 31, - dropout: int = 0.0, - ) -> None: - super().__init__() - inner_dim = expansion_factor * dim - self.layers = nn.Sequential( - nn.LayerNorm(dim), - Rearrange("b n c -> b c n"), - nn.Conv1d(dim, 2 * inner_dim, 1), - GLU(dim=1), - nn.Conv1d( - in_channels=inner_dim, - out_channels=inner_dim, - kernel_size=kernel_size, - groups=inner_dim, - padding="same", - ), - nn.BatchNorm1d(inner_dim), - nn.Mish(inplace=True), - nn.Conv1d(inner_dim, dim, 1), - Rearrange("b c n -> b n c"), - nn.Dropout(dropout), - ) - - def forward(self, x: Tensor) -> Tensor: - return self.layers(x) diff --git a/text_recognizer/networks/conformer/ff.py b/text_recognizer/networks/conformer/ff.py deleted file mode 100644 index 2ef4245..0000000 --- a/text_recognizer/networks/conformer/ff.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Conformer feedforward block.""" -from torch import nn, Tensor - - -class Feedforward(nn.Module): - def __init__( - self, dim: int, expansion_factor: int = 4, dropout: float = 0.0 - ) -> None: - super().__init__() - self.layers = nn.Sequential( - nn.Linear(dim, expansion_factor * dim), - nn.Mish(inplace=True), - nn.Dropout(dropout), - nn.Linear(expansion_factor * dim, dim), - nn.Dropout(dropout), - ) - - def forward(self, x: Tensor) -> Tensor: - return self.layers(x) diff --git a/text_recognizer/networks/conformer/glu.py b/text_recognizer/networks/conformer/glu.py deleted file mode 100644 index 1a7c201..0000000 --- a/text_recognizer/networks/conformer/glu.py +++ /dev/null @@ -1,12 +0,0 @@ -"""GLU layer.""" -from torch import nn, Tensor - - -class GLU(nn.Module): - def __init__(self, dim: int) -> None: - super().__init__() - self.dim = dim - - def forward(self, x: Tensor) -> Tensor: - out, gate = x.chunk(2, dim=self.dim) - return out * gate.sigmoid() diff --git a/text_recognizer/networks/conformer/scale.py b/text_recognizer/networks/conformer/scale.py deleted file mode 100644 index d012b81..0000000 --- a/text_recognizer/networks/conformer/scale.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Scale layer.""" -from typing import Dict -from torch import nn, Tensor - - -class Scale(nn.Module): - def __init__(self, scale: float, fn: nn.Module) -> None: - super().__init__() - self.scale = scale - self.fn = fn - - def forward(self, x: Tensor, **kwargs) -> Tensor: - return self.fn(x, **kwargs) * self.scale diff --git a/text_recognizer/networks/conformer/subsampler.py b/text_recognizer/networks/conformer/subsampler.py deleted file mode 100644 index 133b53a..0000000 --- a/text_recognizer/networks/conformer/subsampler.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Simple convolutional network.""" -from typing import Tuple - -from einops import rearrange -from torch import nn, Tensor - -from text_recognizer.networks.transformer import AxialPositionalEmbedding - - -class Subsampler(nn.Module): - def __init__( - self, - channels: int, - dim: int, - depth: int, - height: int, - pixel_pos_embedding: AxialPositionalEmbedding, - dropout: float = 0.1, - ) -> None: - super().__init__() - self.pixel_pos_embedding = pixel_pos_embedding - self.subsampler, self.projector = self._build( - channels, height, dim, depth, dropout - ) - - def _build( - self, channels: int, height: int, dim: int, depth: int, dropout: float - ) -> Tuple[nn.Sequential, nn.Sequential]: - subsampler = [] - for i in range(depth): - subsampler.append( - nn.Conv2d( - in_channels=1 if i == 0 else channels, - out_channels=channels, - kernel_size=3, - stride=2, - ) - ) - subsampler.append(nn.Mish(inplace=True)) - projector = nn.Sequential( - nn.Linear(channels * height, dim), nn.Dropout(dropout) - ) - return nn.Sequential(*subsampler), projector - - def forward(self, x: Tensor) -> Tensor: - x = self.subsampler(x) - x = self.pixel_pos_embedding(x) - x = rearrange(x, "b c h w -> b w (c h)") - x = self.projector(x) - return x |