diff options
Diffstat (limited to 'text_recognizer/networks/conformer')
-rw-r--r-- | text_recognizer/networks/conformer/conformer.py | 9 | ||||
-rw-r--r-- | text_recognizer/networks/conformer/subsampler.py | 18 |
2 files changed, 21 insertions, 6 deletions
diff --git a/text_recognizer/networks/conformer/conformer.py b/text_recognizer/networks/conformer/conformer.py
index e2dce27..09aad55 100644
--- a/text_recognizer/networks/conformer/conformer.py
+++ b/text_recognizer/networks/conformer/conformer.py
@@ -11,6 +11,7 @@ class Conformer(nn.Module):
     def __init__(
         self,
         dim: int,
+        dim_gru: int,
         num_classes: int,
         subsampler: Type[nn.Module],
         block: ConformerBlock,
@@ -19,10 +20,16 @@ class Conformer(nn.Module):
         super().__init__()
         self.subsampler = subsampler
         self.blocks = nn.ModuleList([deepcopy(block) for _ in range(depth)])
-        self.fc = nn.Linear(dim, num_classes, bias=False)
+        self.gru = nn.GRU(
+            dim, dim_gru, 1, bidirectional=True, batch_first=True, bias=False
+        )
+        self.fc = nn.Linear(dim_gru, num_classes)
 
     def forward(self, x: Tensor) -> Tensor:
         x = self.subsampler(x)
+        B, T, C = x.shape
         for fn in self.blocks:
             x = fn(x)
+        x, _ = self.gru(x)
+        x = x.view(B, T, 2, -1).sum(2)
         return self.fc(x)
diff --git a/text_recognizer/networks/conformer/subsampler.py b/text_recognizer/networks/conformer/subsampler.py
index 53928f1..42a983e 100644
--- a/text_recognizer/networks/conformer/subsampler.py
+++ b/text_recognizer/networks/conformer/subsampler.py
@@ -1,6 +1,7 @@
 """Simple convolutional network."""
 from typing import Tuple
 
+from einops import rearrange
 from torch import nn, Tensor
 
 from text_recognizer.networks.transformer import (
@@ -12,16 +13,20 @@ class Subsampler(nn.Module):
     def __init__(
         self,
         channels: int,
+        dim: int,
         depth: int,
+        height: int,
         pixel_pos_embedding: AxialPositionalEmbedding,
         dropout: float = 0.1,
     ) -> None:
         super().__init__()
         self.pixel_pos_embedding = pixel_pos_embedding
-        self.subsampler, self.projector = self._build(channels, depth, dropout)
+        self.subsampler, self.projector = self._build(
+            channels, height, dim, depth, dropout
+        )
 
     def _build(
-        self, channels: int, depth: int, dropout: float
+        self, channels: int, height: int, dim: int, depth: int, dropout: float
     ) -> Tuple[nn.Sequential, nn.Sequential]:
         subsampler = []
         for i in range(depth):
@@ -34,11 +39,14 @@ class Subsampler(nn.Module):
                 )
             )
             subsampler.append(nn.Mish(inplace=True))
-        projector = nn.Sequential(nn.Linear(channels, channels), nn.Dropout(dropout))
+        projector = nn.Sequential(
+            nn.Linear(channels * height, dim), nn.Dropout(dropout)
+        )
         return nn.Sequential(*subsampler), projector
 
     def forward(self, x: Tensor) -> Tensor:
         x = self.subsampler(x)
         x = self.pixel_pos_embedding(x)
-        x = x.flatten(start_dim=2).permute(0, 2, 1)
-        return self.projector(x)
+        x = rearrange(x, "b c h w -> b w (c h)")
+        x = self.projector(x)
+        return x