From 7b660c13ce3c0edeace1107838e62c559bc6f078 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Wed, 8 Jun 2022 08:41:09 +0200 Subject: Fix conformer net --- text_recognizer/networks/conformer/block.py | 4 ++-- text_recognizer/networks/conformer/conformer.py | 5 ++++- text_recognizer/networks/conformer/subsampler.py | 8 +++----- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/text_recognizer/networks/conformer/block.py b/text_recognizer/networks/conformer/block.py index 4ea33c0..c53f339 100644 --- a/text_recognizer/networks/conformer/block.py +++ b/text_recognizer/networks/conformer/block.py @@ -26,9 +26,9 @@ class ConformerBlock(nn.Module): self.conv = conv self.post_norm = nn.LayerNorm(dim) - def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: + def forward(self, x: Tensor) -> Tensor: x = self.ff_1(x) + x - x = self.attn(x, input_mask=mask) + x + x = self.attn(x) + x x = self.conv(x) + x x = self.ff_2(x) + x return self.post_norm(x) diff --git a/text_recognizer/networks/conformer/conformer.py b/text_recognizer/networks/conformer/conformer.py index 8d0e98e..e2dce27 100644 --- a/text_recognizer/networks/conformer/conformer.py +++ b/text_recognizer/networks/conformer/conformer.py @@ -10,6 +10,8 @@ from text_recognizer.networks.conformer.block import ConformerBlock class Conformer(nn.Module): def __init__( self, + dim: int, + num_classes: int, subsampler: Type[nn.Module], block: ConformerBlock, depth: int, @@ -17,9 +19,10 @@ class Conformer(nn.Module): super().__init__() self.subsampler = subsampler self.blocks = nn.ModuleList([deepcopy(block) for _ in range(depth)]) + self.fc = nn.Linear(dim, num_classes, bias=False) def forward(self, x: Tensor) -> Tensor: x = self.subsampler(x) for fn in self.blocks: x = fn(x) - return x + return self.fc(x) diff --git a/text_recognizer/networks/conformer/subsampler.py b/text_recognizer/networks/conformer/subsampler.py index 2bc0445..53928f1 100644 --- a/text_recognizer/networks/conformer/subsampler.py +++ b/text_recognizer/networks/conformer/subsampler.py @@ -34,13 +34,11 @@ class Subsampler(nn.Module): ) ) subsampler.append(nn.Mish(inplace=True)) - projector = nn.Sequential( - nn.Flatten(start_dim=2), nn.Linear(channels, channels), nn.Dropout(dropout) - ) + projector = nn.Sequential(nn.Linear(channels, channels), nn.Dropout(dropout)) return nn.Sequential(*subsampler), projector def forward(self, x: Tensor) -> Tensor: x = self.subsampler(x) x = self.pixel_pos_embedding(x) - x = self.projector(x) - return x.permute(0, 2, 1) + x = x.flatten(start_dim=2).permute(0, 2, 1) + return self.projector(x) -- cgit v1.2.3-70-g09d2