From bf680dce6bc7dcadd20923a193fc9ab8fbd0a0c6 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 27 Sep 2022 00:10:00 +0200 Subject: Refactor decoder block --- text_recognizer/networks/transformer/decoder.py | 77 ++-------------------- .../networks/transformer/decoder_block.py | 46 +++++++++++++ 2 files changed, 50 insertions(+), 73 deletions(-) create mode 100644 text_recognizer/networks/transformer/decoder_block.py diff --git a/text_recognizer/networks/transformer/decoder.py b/text_recognizer/networks/transformer/decoder.py index efa1e89..c7da226 100644 --- a/text_recognizer/networks/transformer/decoder.py +++ b/text_recognizer/networks/transformer/decoder.py @@ -5,88 +5,19 @@ from typing import Optional, Type from torch import Tensor, nn from text_recognizer.networks.transformer.attention import Attention +from text_recognizer.networks.transformer.decoder_block import DecoderBlock from text_recognizer.networks.transformer.ff import FeedForward -class DecoderBlock(nn.Module): - """Decoder block.""" - - def __init__( - self, - self_attn: Attention, - norm: Type[nn.Module], - ff: FeedForward, - cross_attn: Optional[Attention] = None, - ) -> None: - super().__init__() - self.layers = ("self_attn", "cross_attn", "ff") - self.has_pos_emb = self_attn.rotary_embedding is not None - self.blocks = self._build(self_attn, norm, ff, cross_attn) - - def _build( - self, - self_attn: Attention, - norm: Type[nn.Module], - ff: FeedForward, - cross_attn: Optional[Attention], - ) -> nn.ModuleDict: - return nn.ModuleDict( - { - self.layers[0]: nn.ModuleList([norm, self_attn]), - self.layers[1]: nn.ModuleList([deepcopy(norm), cross_attn]), - self.layers[2]: nn.ModuleList([deepcopy(norm), ff]), - } - ) - - def _apply_block( - self, - layer: str, - x: Tensor, - context: Optional[Tensor] = None, - input_mask: Optional[Tensor] = None, - context_mask: Optional[Tensor] = None, - ) -> Tensor: - """Applies block function.""" - residual = x - norm_fn, layer_fn = self.blocks[layer] - if layer == "self_attn": - out = layer_fn(x=x, input_mask=input_mask) - elif layer == "cross_attn": - out = layer_fn( - x=x, context=context, input_mask=input_mask, context_mask=context_mask - ) - else: - out = layer_fn(x) - out += residual - return norm_fn(out) - - def forward( - self, - x: Tensor, - context: Optional[Tensor] = None, - input_mask: Optional[Tensor] = None, - context_mask: Optional[Tensor] = None, - ) -> Tensor: - """Applies decoder block on input signals.""" - for layer in self.layers: - x = self._apply_block( - layer=layer, - x=x, - context=context, - input_mask=input_mask, - context_mask=context_mask, - ) - return x - - class Decoder(nn.Module): """Decoder Network.""" - def __init__(self, depth: int, block: DecoderBlock) -> None: + def __init__(self, depth: int, dim: int, block: DecoderBlock) -> None: super().__init__() self.depth = depth self.has_pos_emb = block.has_pos_emb self.blocks = nn.ModuleList([deepcopy(block) for _ in range(self.depth)]) + self.ln = nn.LayerNorm(dim) def forward( self, @@ -100,4 +31,4 @@ class Decoder(nn.Module): x = block( x=x, context=context, input_mask=input_mask, context_mask=context_mask ) - return x + return self.ln(x) diff --git a/text_recognizer/networks/transformer/decoder_block.py b/text_recognizer/networks/transformer/decoder_block.py new file mode 100644 index 0000000..e6e7fb8 --- /dev/null +++ b/text_recognizer/networks/transformer/decoder_block.py @@ -0,0 +1,46 @@ +"""Transformer decoder module.""" +from copy import deepcopy +from typing import Optional, Type + +from torch import Tensor, nn + +from text_recognizer.networks.transformer.attention import Attention +from text_recognizer.networks.transformer.ff import FeedForward + + +class DecoderBlock(nn.Module): + """Decoder block.""" + + def __init__( + self, + self_attn: Attention, + norm: Type[nn.Module], + ff: FeedForward, + cross_attn: Optional[Attention] = None, + ) -> None: + super().__init__() + self.ln_attn = norm + self.attn = self_attn + self.ln_cross_attn = deepcopy(norm) + self.cross_attn = cross_attn + self.ln_ff = deepcopy(norm) + self.ff = ff + self.has_pos_emb = self.attn.rotary_embedding is not None + + def forward( + self, + x: Tensor, + context: Optional[Tensor] = None, + input_mask: Optional[Tensor] = None, + context_mask: Optional[Tensor] = None, + ) -> Tensor: + """Applies decoder block on input signals.""" + x = x + self.attn(self.ln_attn(x), input_mask=input_mask) + x += self.cross_attn( + x=self.ln_cross_attn(x), + context=context, + input_mask=input_mask, + context_mask=context_mask, + ) + x += self.ff(self.ln_ff(x)) + return x -- cgit v1.2.3-70-g09d2