-rw-r--r--   text_recognizer/networks/transformer/decoder.py        | 77
-rw-r--r--   text_recognizer/networks/transformer/decoder_block.py  | 46
2 files changed, 50 insertions, 73 deletions
diff --git a/text_recognizer/networks/transformer/decoder.py b/text_recognizer/networks/transformer/decoder.py
index efa1e89..c7da226 100644
--- a/text_recognizer/networks/transformer/decoder.py
+++ b/text_recognizer/networks/transformer/decoder.py
@@ -5,88 +5,19 @@ from typing import Optional, Type
 from torch import Tensor, nn
 
 from text_recognizer.networks.transformer.attention import Attention
+from text_recognizer.networks.transformer.decoder_block import DecoderBlock
 from text_recognizer.networks.transformer.ff import FeedForward
 
 
-class DecoderBlock(nn.Module):
-    """Decoder block."""
-
-    def __init__(
-        self,
-        self_attn: Attention,
-        norm: Type[nn.Module],
-        ff: FeedForward,
-        cross_attn: Optional[Attention] = None,
-    ) -> None:
-        super().__init__()
-        self.layers = ("self_attn", "cross_attn", "ff")
-        self.has_pos_emb = self_attn.rotary_embedding is not None
-        self.blocks = self._build(self_attn, norm, ff, cross_attn)
-
-    def _build(
-        self,
-        self_attn: Attention,
-        norm: Type[nn.Module],
-        ff: FeedForward,
-        cross_attn: Optional[Attention],
-    ) -> nn.ModuleDict:
-        return nn.ModuleDict(
-            {
-                self.layers[0]: nn.ModuleList([norm, self_attn]),
-                self.layers[1]: nn.ModuleList([deepcopy(norm), cross_attn]),
-                self.layers[2]: nn.ModuleList([deepcopy(norm), ff]),
-            }
-        )
-
-    def _apply_block(
-        self,
-        layer: str,
-        x: Tensor,
-        context: Optional[Tensor] = None,
-        input_mask: Optional[Tensor] = None,
-        context_mask: Optional[Tensor] = None,
-    ) -> Tensor:
-        """Applies block function."""
-        residual = x
-        norm_fn, layer_fn = self.blocks[layer]
-        if layer == "self_attn":
-            out = layer_fn(x=x, input_mask=input_mask)
-        elif layer == "cross_attn":
-            out = layer_fn(
-                x=x, context=context, input_mask=input_mask, context_mask=context_mask
-            )
-        else:
-            out = layer_fn(x)
-        out += residual
-        return norm_fn(out)
-
-    def forward(
-        self,
-        x: Tensor,
-        context: Optional[Tensor] = None,
-        input_mask: Optional[Tensor] = None,
-        context_mask: Optional[Tensor] = None,
-    ) -> Tensor:
-        """Applies decoder block on input signals."""
-        for layer in self.layers:
-            x = self._apply_block(
-                layer=layer,
-                x=x,
-                context=context,
-                input_mask=input_mask,
-                context_mask=context_mask,
-            )
-        return x
-
-
 class Decoder(nn.Module):
     """Decoder Network."""
 
-    def __init__(self, depth: int, block: DecoderBlock) -> None:
+    def __init__(self, depth: int, dim: int, block: DecoderBlock) -> None:
         super().__init__()
         self.depth = depth
         self.has_pos_emb = block.has_pos_emb
         self.blocks = nn.ModuleList([deepcopy(block) for _ in range(self.depth)])
+        self.ln = nn.LayerNorm(dim)
 
     def forward(
         self,
@@ -100,4 +31,4 @@ class Decoder(nn.Module):
             x = block(
                 x=x, context=context, input_mask=input_mask, context_mask=context_mask
             )
-        return x
+        return self.ln(x)
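Note on the decoder.py change: Decoder now takes the model width dim and applies a trailing nn.LayerNorm to whatever leaves the last block, which lines up with the switch to pre-norm blocks in decoder_block.py below (pre-norm blocks add unnormalized residuals, so the stream is only normalized once at the very end). A quick, self-contained illustration of what that final layer does; the shapes and values are made up for the demo and are not the project's API:

import torch
from torch import nn

dim = 128
ln = nn.LayerNorm(dim)                      # the layer Decoder now builds from dim
x = torch.randn(2, 10, dim) * 5 + 3         # stand-in for the unnormalized residual stream
out = ln(x)                                 # what Decoder.forward now returns
print(out.shape)                            # torch.Size([2, 10, 128])
print(out.mean(-1).abs().max())             # close to 0: per-position mean
print(out.std(-1, unbiased=False).mean())   # close to 1: per-position std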
diff --git a/text_recognizer/networks/transformer/decoder_block.py b/text_recognizer/networks/transformer/decoder_block.py
new file mode 100644
index 0000000..e6e7fb8
--- /dev/null
+++ b/text_recognizer/networks/transformer/decoder_block.py
@@ -0,0 +1,46 @@
+"""Transformer decoder module."""
+from copy import deepcopy
+from typing import Optional, Type
+
+from torch import Tensor, nn
+
+from text_recognizer.networks.transformer.attention import Attention
+from text_recognizer.networks.transformer.ff import FeedForward
+
+
+class DecoderBlock(nn.Module):
+    """Decoder block."""
+
+    def __init__(
+        self,
+        self_attn: Attention,
+        norm: Type[nn.Module],
+        ff: FeedForward,
+        cross_attn: Optional[Attention] = None,
+    ) -> None:
+        super().__init__()
+        self.ln_attn = norm
+        self.attn = self_attn
+        self.ln_cross_attn = deepcopy(norm)
+        self.cross_attn = cross_attn
+        self.ln_ff = deepcopy(norm)
+        self.ff = ff
+        self.has_pos_emb = self.attn.rotary_embedding is not None
+
+    def forward(
+        self,
+        x: Tensor,
+        context: Optional[Tensor] = None,
+        input_mask: Optional[Tensor] = None,
+        context_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        """Applies decoder block on input signals."""
+        x = x + self.attn(self.ln_attn(x), input_mask=input_mask)
+        x += self.cross_attn(
+            x=self.ln_cross_attn(x),
+            context=context,
+            input_mask=input_mask,
+            context_mask=context_mask,
+        )
+        x += self.ff(self.ln_ff(x))
+        return x
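For orientation, a minimal, self-contained sketch of the norm -> sublayer -> residual-add wiring that the new DecoderBlock and Decoder use. It deliberately swaps in torch's nn.MultiheadAttention and a plain MLP for this repo's Attention and FeedForward classes (whose constructors and mask handling are not shown in this diff), so every name, hyperparameter, and shape below is illustrative only:

import torch
from torch import Tensor, nn


class PreNormBlockSketch(nn.Module):
    """Stand-in for DecoderBlock: pre-norm self-attention, cross-attention, feed-forward."""

    def __init__(self, dim: int, heads: int) -> None:
        super().__init__()
        self.ln_attn = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ln_cross_attn = nn.LayerNorm(dim)
        self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ln_ff = nn.LayerNorm(dim)
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x: Tensor, context: Tensor) -> Tensor:
        # Same wiring as the new DecoderBlock.forward: normalize the sublayer
        # input, run the sublayer, then add the result back onto the stream.
        y = self.ln_attn(x)
        x = x + self.attn(y, y, y, need_weights=False)[0]
        y = self.ln_cross_attn(x)
        x = x + self.cross_attn(y, context, context, need_weights=False)[0]
        x = x + self.ff(self.ln_ff(x))
        return x


# Mirrors the refactored Decoder: a stack of identical blocks plus one final
# LayerNorm(dim) (the real Decoder deep-copies a single prototype block).
dim, heads, depth = 64, 4, 2
blocks = nn.ModuleList([PreNormBlockSketch(dim, heads) for _ in range(depth)])
final_ln = nn.LayerNorm(dim)

x = torch.randn(2, 10, dim)         # (batch, target_len, dim) token embeddings
context = torch.randn(2, 32, dim)   # (batch, source_len, dim) encoder features
for block in blocks:
    x = block(x, context)
out = final_ln(x)
print(out.shape)                    # torch.Size([2, 10, 64])

The real DecoderBlock additionally threads input_mask and context_mask through to its attention layers; those are omitted here because their format depends on the repo's Attention implementation.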