diff options
Diffstat (limited to 'text_recognizer/networks/conformer/subsampler.py')
-rw-r--r-- | text_recognizer/networks/conformer/subsampler.py | 18 |
1 files changed, 13 insertions, 5 deletions
diff --git a/text_recognizer/networks/conformer/subsampler.py b/text_recognizer/networks/conformer/subsampler.py index 53928f1..42a983e 100644 --- a/text_recognizer/networks/conformer/subsampler.py +++ b/text_recognizer/networks/conformer/subsampler.py @@ -1,6 +1,7 @@ """Simple convolutional network.""" from typing import Tuple +from einops import rearrange from torch import nn, Tensor from text_recognizer.networks.transformer import ( @@ -12,16 +13,20 @@ class Subsampler(nn.Module): def __init__( self, channels: int, + dim: int, depth: int, + height: int, pixel_pos_embedding: AxialPositionalEmbedding, dropout: float = 0.1, ) -> None: super().__init__() self.pixel_pos_embedding = pixel_pos_embedding - self.subsampler, self.projector = self._build(channels, depth, dropout) + self.subsampler, self.projector = self._build( + channels, height, dim, depth, dropout + ) def _build( - self, channels: int, depth: int, dropout: float + self, channels: int, height: int, dim: int, depth: int, dropout: float ) -> Tuple[nn.Sequential, nn.Sequential]: subsampler = [] for i in range(depth): @@ -34,11 +39,14 @@ class Subsampler(nn.Module): ) ) subsampler.append(nn.Mish(inplace=True)) - projector = nn.Sequential(nn.Linear(channels, channels), nn.Dropout(dropout)) + projector = nn.Sequential( + nn.Linear(channels * height, dim), nn.Dropout(dropout) + ) return nn.Sequential(*subsampler), projector def forward(self, x: Tensor) -> Tensor: x = self.subsampler(x) x = self.pixel_pos_embedding(x) - x = x.flatten(start_dim=2).permute(0, 2, 1) - return self.projector(x) + x = rearrange(x, "b c h w -> b w (c h)") + x = self.projector(x) + return x |