diff options
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/networks/transformer/decoder.py | 26 |
1 files changed, 14 insertions, 12 deletions
diff --git a/text_recognizer/networks/transformer/decoder.py b/text_recognizer/networks/transformer/decoder.py index a58f7bd..1812e40 100644 --- a/text_recognizer/networks/transformer/decoder.py +++ b/text_recognizer/networks/transformer/decoder.py @@ -1,6 +1,6 @@ """Transformer decoder module.""" from copy import deepcopy -from typing import Optional, Tuple, Type +from typing import Optional, Type from torch import nn, Tensor @@ -19,8 +19,8 @@ class DecoderBlock(nn.Module): cross_attn: Optional[Attention] = None, ) -> None: super().__init__() - self._layers = ("self_attn", "cross_attn", "ff") - self._blocks = self._build(self_attn, norm, ff, cross_attn) + self.layers = ("self_attn", "cross_attn", "ff") + self.blocks = self._build(self_attn, norm, ff, cross_attn) def _build( self, @@ -37,7 +37,7 @@ class DecoderBlock(nn.Module): } ) - def _apply( + def _apply_block( self, layer: str, x: Tensor, @@ -45,8 +45,9 @@ class DecoderBlock(nn.Module): input_mask: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, ) -> Tensor: + """Applies block function.""" residual = x - norm_fn, layer_fn = self._blocks[layer] + norm_fn, layer_fn = self.blocks[layer] if layer == "self_attn": out = layer_fn(x=x, input_mask=input_mask) elif layer == "cross_attn": @@ -66,8 +67,8 @@ class DecoderBlock(nn.Module): context_mask: Optional[Tensor] = None, ) -> Tensor: """Applies decoder block on input signals.""" - for layer in self._layers: - x = self._apply( + for layer in self.layers: + x = self._apply_block( layer=layer, x=x, context=context, @@ -77,13 +78,14 @@ class DecoderBlock(nn.Module): return x -class Decoder: +class Decoder(nn.Module): """Decoder Network.""" - def __init__(self, depth: int, block: DecoderBlock) -> None: + def __init__(self, depth: int, has_pos_emb: bool, block: DecoderBlock) -> None: + super().__init__() self.depth = depth - self.has_pos_emb: bool = block.rotary_embedding is not None - self._block = nn.ModuleList([deepcopy(block) for _ in range(self.depth)]) + self.has_pos_emb = has_pos_emb + self.blocks = nn.ModuleList([deepcopy(block) for _ in range(self.depth)]) def forward( self, @@ -93,7 +95,7 @@ class Decoder: context_mask: Optional[Tensor] = None, ) -> Tensor: """Applies the network to the signals.""" - for block in self._blocks: + for block in self.blocks: x = block( x=x, context=context, input_mask=input_mask, context_mask=context_mask ) |