Diffstat (limited to 'text_recognizer/networks/conformer')
-rw-r--r--  text_recognizer/networks/conformer/__init__.py   |  7
-rw-r--r--  text_recognizer/networks/conformer/attention.py  | 49
-rw-r--r--  text_recognizer/networks/conformer/block.py      | 34
-rw-r--r--  text_recognizer/networks/conformer/conformer.py  | 35
-rw-r--r--  text_recognizer/networks/conformer/conv.py       | 40
-rw-r--r--  text_recognizer/networks/conformer/ff.py         | 19
-rw-r--r--  text_recognizer/networks/conformer/glu.py        | 12
-rw-r--r--  text_recognizer/networks/conformer/scale.py      | 13
-rw-r--r--  text_recognizer/networks/conformer/subsampler.py | 50
9 files changed, 0 insertions, 259 deletions
diff --git a/text_recognizer/networks/conformer/__init__.py b/text_recognizer/networks/conformer/__init__.py
deleted file mode 100644
index 8951481..0000000
--- a/text_recognizer/networks/conformer/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from text_recognizer.networks.conformer.block import ConformerBlock
-from text_recognizer.networks.conformer.ff import Feedforward
-from text_recognizer.networks.conformer.glu import GLU
-from text_recognizer.networks.conformer.conformer import Conformer
-from text_recognizer.networks.conformer.conv import ConformerConv
-from text_recognizer.networks.conformer.subsampler import Subsampler
-from text_recognizer.networks.conformer.attention import Attention
diff --git a/text_recognizer/networks/conformer/attention.py b/text_recognizer/networks/conformer/attention.py
deleted file mode 100644
index e56e572..0000000
--- a/text_recognizer/networks/conformer/attention.py
+++ /dev/null
@@ -1,49 +0,0 @@
-"""Efficient self attention."""
-from einops import rearrange
-import torch
-import torch.nn.functional as F
-from torch import einsum, nn, Tensor
-
-
-class LayerNorm(nn.Module):
- def __init__(self, dim: int) -> None:
- super().__init__()
- self.gamma = nn.Parameter(torch.ones(dim))
- self.register_buffer("beta", torch.zeros(dim))
-
- def forward(self, x: Tensor) -> Tensor:
- return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
-
-
-class SwiGLU(nn.Module):
- def forward(self, x: Tensor) -> Tensor:
- x, gate = x.chunk(2, dim=-1)
- return F.silu(gate) * x
-
-
-class Attention(nn.Module):
- def __init__(
- self, dim: int, dim_head: int = 64, heads: int = 8, mult: int = 4
- ) -> None:
- super().__init__()
- self.norm = LayerNorm(dim)
- attn_inner_dim = heads * dim_head
- ff_inner_dim = mult * dim
- self.heads = heads
- self.scale = dim_head ** -0.5
- self.fused_dims = (attn_inner_dim, dim_head, dim_head, (2 * ff_inner_dim))
- self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
- self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)
- self.ff_out = nn.Sequential(SwiGLU(), nn.Linear(ff_inner_dim, dim, bias=False))
-
- def forward(self, x: Tensor) -> Tensor:
- h = self.heads
- x = self.norm(x)
- q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)
- q = rearrange(q, "b n (h d) -> b h n d", h=h)
- q = q * self.scale
- sim = einsum("b h i d, b j d -> b h i j", q, k)
- attn = sim.softmax(dim=-1)
- out = einsum("b h i j, b j d -> b h i d", attn, v)
- out = rearrange(out, "b h n d -> b n (h d)")
- return self.attn_out(out) + self.ff_out(ff)
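For context, a minimal usage sketch of the module removed above (hypothetical, assuming the pre-deletion Attention class is importable): it fuses a multi-query self-attention branch (single-head keys and values of width dim_head) with a SwiGLU feedforward behind one shared projection, so a single call returns the sum of both branches at the input width.

    import torch

    attn = Attention(dim=256, dim_head=64, heads=8, mult=4)
    x = torch.randn(2, 100, 256)  # (batch, sequence length, dim)
    out = attn(x)                 # (2, 100, 256) = attn_out(...) + ff_out(...); the caller adds the residual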
diff --git a/text_recognizer/networks/conformer/block.py b/text_recognizer/networks/conformer/block.py
deleted file mode 100644
index c53f339..0000000
--- a/text_recognizer/networks/conformer/block.py
+++ /dev/null
@@ -1,34 +0,0 @@
-"""Conformer block."""
-from copy import deepcopy
-from typing import Optional
-
-from torch import nn, Tensor
-from text_recognizer.networks.conformer.conv import ConformerConv
-
-from text_recognizer.networks.conformer.ff import Feedforward
-from text_recognizer.networks.conformer.scale import Scale
-from text_recognizer.networks.transformer.attention import Attention
-from text_recognizer.networks.transformer.norm import PreNorm
-
-
-class ConformerBlock(nn.Module):
- def __init__(
- self,
- dim: int,
- ff: Feedforward,
- attn: Attention,
- conv: ConformerConv,
- ) -> None:
- super().__init__()
- self.attn = PreNorm(dim, attn)
- self.ff_1 = Scale(0.5, ff)
- self.ff_2 = deepcopy(self.ff_1)
- self.conv = conv
- self.post_norm = nn.LayerNorm(dim)
-
- def forward(self, x: Tensor) -> Tensor:
- x = self.ff_1(x) + x
- x = self.attn(x) + x
- x = self.conv(x) + x
- x = self.ff_2(x) + x
- return self.post_norm(x)
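For reference, the removed block follows the macaron layout of the Conformer paper: two half-scaled feedforwards sandwich the attention and convolution branches, and ff_2 is a deep copy of ff_1, so the two feedforwards have independent weights. Assuming PreNorm applies a LayerNorm before its wrapped module (as the name suggests), the forward pass above amounts to:

    x1 = x  + 0.5 * FF1(x)              # Scale(0.5, ff)
    x2 = x1 + Attn(LayerNorm(x1))       # PreNorm(dim, attn)
    x3 = x2 + ConformerConv(x2)         # the conv module applies its own LayerNorm internally
    y  = LayerNorm(x3 + 0.5 * FF2(x3))  # post_norm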
diff --git a/text_recognizer/networks/conformer/conformer.py b/text_recognizer/networks/conformer/conformer.py
deleted file mode 100644
index 09aad55..0000000
--- a/text_recognizer/networks/conformer/conformer.py
+++ /dev/null
@@ -1,35 +0,0 @@
-"""Conformer module."""
-from copy import deepcopy
-from typing import Type
-
-from torch import nn, Tensor
-
-from text_recognizer.networks.conformer.block import ConformerBlock
-
-
-class Conformer(nn.Module):
- def __init__(
- self,
- dim: int,
- dim_gru: int,
- num_classes: int,
- subsampler: Type[nn.Module],
- block: ConformerBlock,
- depth: int,
- ) -> None:
- super().__init__()
- self.subsampler = subsampler
- self.blocks = nn.ModuleList([deepcopy(block) for _ in range(depth)])
- self.gru = nn.GRU(
- dim, dim_gru, 1, bidirectional=True, batch_first=True, bias=False
- )
- self.fc = nn.Linear(dim_gru, num_classes)
-
- def forward(self, x: Tensor) -> Tensor:
- x = self.subsampler(x)
- B, T, C = x.shape
- for fn in self.blocks:
- x = fn(x)
- x, _ = self.gru(x)
- x = x.view(B, T, 2, -1).sum(2)
- return self.fc(x)
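For context, a hypothetical shape walk-through of the removed forward pass (the exact sequence length T depends on the subsampler's striding):

    # x: (B, 1, H, W) image          -> subsampler                  -> (B, T, dim)
    # each ConformerBlock preserves                                    (B, T, dim)
    # bidirectional GRU (batch_first)                                -> (B, T, 2 * dim_gru)
    # x.view(B, T, 2, -1).sum(2) sums the two directions             -> (B, T, dim_gru)
    # fc                                                             -> (B, T, num_classes), per-time-step logits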
diff --git a/text_recognizer/networks/conformer/conv.py b/text_recognizer/networks/conformer/conv.py
deleted file mode 100644
index ac13f5d..0000000
--- a/text_recognizer/networks/conformer/conv.py
+++ /dev/null
@@ -1,40 +0,0 @@
-"""Conformer convolutional block."""
-from einops import rearrange
-from einops.layers.torch import Rearrange
-from torch import nn, Tensor
-
-
-from text_recognizer.networks.conformer.glu import GLU
-
-
-class ConformerConv(nn.Module):
- def __init__(
- self,
- dim: int,
- expansion_factor: int = 2,
- kernel_size: int = 31,
- dropout: int = 0.0,
- ) -> None:
- super().__init__()
- inner_dim = expansion_factor * dim
- self.layers = nn.Sequential(
- nn.LayerNorm(dim),
- Rearrange("b n c -> b c n"),
- nn.Conv1d(dim, 2 * inner_dim, 1),
- GLU(dim=1),
- nn.Conv1d(
- in_channels=inner_dim,
- out_channels=inner_dim,
- kernel_size=kernel_size,
- groups=inner_dim,
- padding="same",
- ),
- nn.BatchNorm1d(inner_dim),
- nn.Mish(inplace=True),
- nn.Conv1d(inner_dim, dim, 1),
- Rearrange("b c n -> b n c"),
- nn.Dropout(dropout),
- )
-
- def forward(self, x: Tensor) -> Tensor:
- return self.layers(x)
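For reference, the shapes through the removed convolution module for an input of (B, N, dim), with inner_dim = expansion_factor * dim:

    # (B, N, dim) --LayerNorm--> Rearrange "b n c -> b c n" --> (B, dim, N)
    # pointwise Conv1d: dim -> 2 * inner_dim
    # GLU(dim=1): gates and halves the channels back to inner_dim
    # depthwise Conv1d (groups=inner_dim, padding="same"): keeps length N
    # BatchNorm1d + Mish, then pointwise Conv1d: inner_dim -> dim
    # Rearrange "b c n -> b n c" and Dropout                --> (B, N, dim)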
diff --git a/text_recognizer/networks/conformer/ff.py b/text_recognizer/networks/conformer/ff.py
deleted file mode 100644
index 2ef4245..0000000
--- a/text_recognizer/networks/conformer/ff.py
+++ /dev/null
@@ -1,19 +0,0 @@
-"""Conformer feedforward block."""
-from torch import nn, Tensor
-
-
-class Feedforward(nn.Module):
- def __init__(
- self, dim: int, expansion_factor: int = 4, dropout: float = 0.0
- ) -> None:
- super().__init__()
- self.layers = nn.Sequential(
- nn.Linear(dim, expansion_factor * dim),
- nn.Mish(inplace=True),
- nn.Dropout(dropout),
- nn.Linear(expansion_factor * dim, dim),
- nn.Dropout(dropout),
- )
-
- def forward(self, x: Tensor) -> Tensor:
- return self.layers(x)
diff --git a/text_recognizer/networks/conformer/glu.py b/text_recognizer/networks/conformer/glu.py
deleted file mode 100644
index 1a7c201..0000000
--- a/text_recognizer/networks/conformer/glu.py
+++ /dev/null
@@ -1,12 +0,0 @@
-"""GLU layer."""
-from torch import nn, Tensor
-
-
-class GLU(nn.Module):
- def __init__(self, dim: int) -> None:
- super().__init__()
- self.dim = dim
-
- def forward(self, x: Tensor) -> Tensor:
- out, gate = x.chunk(2, dim=self.dim)
- return out * gate.sigmoid()
diff --git a/text_recognizer/networks/conformer/scale.py b/text_recognizer/networks/conformer/scale.py
deleted file mode 100644
index d012b81..0000000
--- a/text_recognizer/networks/conformer/scale.py
+++ /dev/null
@@ -1,13 +0,0 @@
-"""Scale layer."""
-from typing import Dict
-from torch import nn, Tensor
-
-
-class Scale(nn.Module):
- def __init__(self, scale: float, fn: nn.Module) -> None:
- super().__init__()
- self.scale = scale
- self.fn = fn
-
- def forward(self, x: Tensor, **kwargs) -> Tensor:
- return self.fn(x, **kwargs) * self.scale
diff --git a/text_recognizer/networks/conformer/subsampler.py b/text_recognizer/networks/conformer/subsampler.py
deleted file mode 100644
index 133b53a..0000000
--- a/text_recognizer/networks/conformer/subsampler.py
+++ /dev/null
@@ -1,50 +0,0 @@
-"""Simple convolutional network."""
-from typing import Tuple
-
-from einops import rearrange
-from torch import nn, Tensor
-
-from text_recognizer.networks.transformer import AxialPositionalEmbedding
-
-
-class Subsampler(nn.Module):
- def __init__(
- self,
- channels: int,
- dim: int,
- depth: int,
- height: int,
- pixel_pos_embedding: AxialPositionalEmbedding,
- dropout: float = 0.1,
- ) -> None:
- super().__init__()
- self.pixel_pos_embedding = pixel_pos_embedding
- self.subsampler, self.projector = self._build(
- channels, height, dim, depth, dropout
- )
-
- def _build(
- self, channels: int, height: int, dim: int, depth: int, dropout: float
- ) -> Tuple[nn.Sequential, nn.Sequential]:
- subsampler = []
- for i in range(depth):
- subsampler.append(
- nn.Conv2d(
- in_channels=1 if i == 0 else channels,
- out_channels=channels,
- kernel_size=3,
- stride=2,
- )
- )
- subsampler.append(nn.Mish(inplace=True))
- projector = nn.Sequential(
- nn.Linear(channels * height, dim), nn.Dropout(dropout)
- )
- return nn.Sequential(*subsampler), projector
-
- def forward(self, x: Tensor) -> Tensor:
- x = self.subsampler(x)
- x = self.pixel_pos_embedding(x)
- x = rearrange(x, "b c h w -> b w (c h)")
- x = self.projector(x)
- return x
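For reference, each Conv2d above uses kernel_size=3, stride=2, and no padding, so a spatial size S shrinks to (S - 3) // 2 + 1 per layer; the height argument must therefore be the feature-map height after all depth convolutions, since the projector maps channels * height features per time step after the "b c h w -> b w (c h)" rearrange. A worked example with hypothetical numbers:

    # depth = 2, input height 56:
    #   after conv 1: (56 - 3) // 2 + 1 = 27
    #   after conv 2: (27 - 3) // 2 + 1 = 13
    # so the module would be built with height=13 for
    # nn.Linear(channels * height, dim) to line up.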