Diffstat (limited to 'text_recognizer/networks/transformer/decoder.py')
-rw-r--r--  text_recognizer/networks/transformer/decoder.py  77
1 file changed, 4 insertions(+), 73 deletions(-)
diff --git a/text_recognizer/networks/transformer/decoder.py b/text_recognizer/networks/transformer/decoder.py
index efa1e89..c7da226 100644
--- a/text_recognizer/networks/transformer/decoder.py
+++ b/text_recognizer/networks/transformer/decoder.py
@@ -5,88 +5,19 @@ from typing import Optional, Type
from torch import Tensor, nn
from text_recognizer.networks.transformer.attention import Attention
+from text_recognizer.networks.transformer.decoder_block import DecoderBlock
from text_recognizer.networks.transformer.ff import FeedForward
-class DecoderBlock(nn.Module):
-    """Decoder block."""
-
-    def __init__(
-        self,
-        self_attn: Attention,
-        norm: Type[nn.Module],
-        ff: FeedForward,
-        cross_attn: Optional[Attention] = None,
-    ) -> None:
-        super().__init__()
-        self.layers = ("self_attn", "cross_attn", "ff")
-        self.has_pos_emb = self_attn.rotary_embedding is not None
-        self.blocks = self._build(self_attn, norm, ff, cross_attn)
-
-    def _build(
-        self,
-        self_attn: Attention,
-        norm: Type[nn.Module],
-        ff: FeedForward,
-        cross_attn: Optional[Attention],
-    ) -> nn.ModuleDict:
-        return nn.ModuleDict(
-            {
-                self.layers[0]: nn.ModuleList([norm, self_attn]),
-                self.layers[1]: nn.ModuleList([deepcopy(norm), cross_attn]),
-                self.layers[2]: nn.ModuleList([deepcopy(norm), ff]),
-            }
-        )
-
-    def _apply_block(
-        self,
-        layer: str,
-        x: Tensor,
-        context: Optional[Tensor] = None,
-        input_mask: Optional[Tensor] = None,
-        context_mask: Optional[Tensor] = None,
-    ) -> Tensor:
-        """Applies block function."""
-        residual = x
-        norm_fn, layer_fn = self.blocks[layer]
-        if layer == "self_attn":
-            out = layer_fn(x=x, input_mask=input_mask)
-        elif layer == "cross_attn":
-            out = layer_fn(
-                x=x, context=context, input_mask=input_mask, context_mask=context_mask
-            )
-        else:
-            out = layer_fn(x)
-        out += residual
-        return norm_fn(out)
-
-    def forward(
-        self,
-        x: Tensor,
-        context: Optional[Tensor] = None,
-        input_mask: Optional[Tensor] = None,
-        context_mask: Optional[Tensor] = None,
-    ) -> Tensor:
-        """Applies decoder block on input signals."""
-        for layer in self.layers:
-            x = self._apply_block(
-                layer=layer,
-                x=x,
-                context=context,
-                input_mask=input_mask,
-                context_mask=context_mask,
-            )
-        return x
-
-
class Decoder(nn.Module):
"""Decoder Network."""
- def __init__(self, depth: int, block: DecoderBlock) -> None:
+ def __init__(self, depth: int, dim: int, block: DecoderBlock) -> None:
super().__init__()
self.depth = depth
self.has_pos_emb = block.has_pos_emb
self.blocks = nn.ModuleList([deepcopy(block) for _ in range(self.depth)])
+ self.ln = nn.LayerNorm(dim)
def forward(
self,
@@ -100,4 +31,4 @@ class Decoder(nn.Module):
x = block(
x=x, context=context, input_mask=input_mask, context_mask=context_mask
)
- return x
+ return self.ln(x)
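
For orientation, below is a minimal usage sketch of the refactored classes. Only the Decoder(depth, dim, block) and DecoderBlock(self_attn, norm, ff, cross_attn) signatures are taken from this diff; the Attention and FeedForward keyword arguments, as well as the tensor shapes and the call to decoder(...), are illustrative assumptions rather than the project's confirmed API.

# Minimal sketch, not part of the diff. The Attention/FeedForward kwargs and the
# example tensors below are assumptions for illustration; only the Decoder and
# DecoderBlock constructor signatures are confirmed by the code above.
import torch
from torch import nn

from text_recognizer.networks.transformer.attention import Attention
from text_recognizer.networks.transformer.decoder import Decoder
from text_recognizer.networks.transformer.decoder_block import DecoderBlock
from text_recognizer.networks.transformer.ff import FeedForward

dim = 512
block = DecoderBlock(
    self_attn=Attention(dim=dim, num_heads=8, causal=True),  # assumed kwargs
    norm=nn.LayerNorm(dim),
    ff=FeedForward(dim=dim),                                  # assumed kwargs
    cross_attn=Attention(dim=dim, num_heads=8),               # assumed kwargs
)
decoder = Decoder(depth=4, dim=dim, block=block)  # block is deep-copied depth times

x = torch.randn(1, 32, dim)       # decoder input embeddings (assumed shape)
memory = torch.randn(1, 64, dim)  # encoder output for cross-attention (assumed shape)
out = decoder(x=x, context=memory)  # output now passes through the added nn.LayerNorm(dim)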