diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-27 00:10:00 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-27 00:10:00 +0200 |
commit | bf680dce6bc7dcadd20923a193fc9ab8fbd0a0c6 (patch) | |
tree | 5679e4d28673b8898769266f8a8d856a4d924b07 /text_recognizer/networks/transformer/decoder.py | |
parent | 096823d111117ac5efe954db7f3db26cccabda6c (diff) |
Refactor decoder block
Diffstat (limited to 'text_recognizer/networks/transformer/decoder.py')
-rw-r--r-- | text_recognizer/networks/transformer/decoder.py | 77 |
1 file changed, 4 insertions, 73 deletions
diff --git a/text_recognizer/networks/transformer/decoder.py b/text_recognizer/networks/transformer/decoder.py index efa1e89..c7da226 100644 --- a/text_recognizer/networks/transformer/decoder.py +++ b/text_recognizer/networks/transformer/decoder.py @@ -5,88 +5,19 @@ from typing import Optional, Type from torch import Tensor, nn from text_recognizer.networks.transformer.attention import Attention +from text_recognizer.networks.transformer.decoder_block import DecoderBlock from text_recognizer.networks.transformer.ff import FeedForward -class DecoderBlock(nn.Module): - """Decoder block.""" - - def __init__( - self, - self_attn: Attention, - norm: Type[nn.Module], - ff: FeedForward, - cross_attn: Optional[Attention] = None, - ) -> None: - super().__init__() - self.layers = ("self_attn", "cross_attn", "ff") - self.has_pos_emb = self_attn.rotary_embedding is not None - self.blocks = self._build(self_attn, norm, ff, cross_attn) - - def _build( - self, - self_attn: Attention, - norm: Type[nn.Module], - ff: FeedForward, - cross_attn: Optional[Attention], - ) -> nn.ModuleDict: - return nn.ModuleDict( - { - self.layers[0]: nn.ModuleList([norm, self_attn]), - self.layers[1]: nn.ModuleList([deepcopy(norm), cross_attn]), - self.layers[2]: nn.ModuleList([deepcopy(norm), ff]), - } - ) - - def _apply_block( - self, - layer: str, - x: Tensor, - context: Optional[Tensor] = None, - input_mask: Optional[Tensor] = None, - context_mask: Optional[Tensor] = None, - ) -> Tensor: - """Applies block function.""" - residual = x - norm_fn, layer_fn = self.blocks[layer] - if layer == "self_attn": - out = layer_fn(x=x, input_mask=input_mask) - elif layer == "cross_attn": - out = layer_fn( - x=x, context=context, input_mask=input_mask, context_mask=context_mask - ) - else: - out = layer_fn(x) - out += residual - return norm_fn(out) - - def forward( - self, - x: Tensor, - context: Optional[Tensor] = None, - input_mask: Optional[Tensor] = None, - context_mask: Optional[Tensor] = None, - ) -> 
Tensor: - """Applies decoder block on input signals.""" - for layer in self.layers: - x = self._apply_block( - layer=layer, - x=x, - context=context, - input_mask=input_mask, - context_mask=context_mask, - ) - return x - - class Decoder(nn.Module): """Decoder Network.""" - def __init__(self, depth: int, block: DecoderBlock) -> None: + def __init__(self, depth: int, dim: int, block: DecoderBlock) -> None: super().__init__() self.depth = depth self.has_pos_emb = block.has_pos_emb self.blocks = nn.ModuleList([deepcopy(block) for _ in range(self.depth)]) + self.ln = nn.LayerNorm(dim) def forward( self, @@ -100,4 +31,4 @@ class Decoder(nn.Module): x = block( x=x, context=context, input_mask=input_mask, context_mask=context_mask ) - return x + return self.ln(x) |