diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-02 23:43:37 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-02 23:43:37 +0200 |
commit | 0982e09066a3c31cb8b2fc32b5ecbc2bb64952fb (patch) | |
tree | 79f2abda178415d7862ce76e65e66bb2de76be7c /text_recognizer/networks/conformer | |
parent | 4cdc50e2c89015f49973eadddfbee88ba2744f06 (diff) |
Add conformer block
Diffstat (limited to 'text_recognizer/networks/conformer')
-rw-r--r-- | text_recognizer/networks/conformer/block.py | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/text_recognizer/networks/conformer/block.py b/text_recognizer/networks/conformer/block.py index e69de29..d9782e8 100644 --- a/text_recognizer/networks/conformer/block.py +++ b/text_recognizer/networks/conformer/block.py @@ -0,0 +1,34 @@ +"""Conformer block.""" +from copy import deepcopy +from typing import Optional + +from torch import nn, Tensor +from text_recognizer.networks.conformer.conv import ConformerConv + +from text_recognizer.networks.conformer.mlp import MLP +from text_recognizer.networks.conformer.scale import Scale +from text_recognizer.networks.transformer.attention import Attention +from text_recognizer.networks.transformer.norm import PreNorm + + +class ConformerBlock(nn.Module): + def __init__( + self, + dim: int, + ff: MLP, + attn: Attention, + conv: ConformerConv, + ) -> None: + super().__init__() + self.attn = PreNorm(dim, attn) + self.ff_1 = Scale(0.5, ff) + self.ff_2 = deepcopy(self.ff_1) + self.conv = conv + self.post_norm = nn.LayerNorm(dim) + + def forward(self, x: Tensor, mask: Optional[Tensor]) -> Tensor: + x = self.ff_1(x) + x + x = self.attn(x, mask=mask) + x + x = self.conv(x) + x + x = self.ff_2(x) + x + return self.post_norm(x) |