summaryrefslogtreecommitdiff
path: root/text_recognizer/networks
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-27 00:10:00 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-27 00:10:00 +0200
commitbf680dce6bc7dcadd20923a193fc9ab8fbd0a0c6 (patch)
tree5679e4d28673b8898769266f8a8d856a4d924b07 /text_recognizer/networks
parent096823d111117ac5efe954db7f3db26cccabda6c (diff)
Refactor decoder block
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--text_recognizer/networks/transformer/decoder.py77
-rw-r--r--text_recognizer/networks/transformer/decoder_block.py46
2 files changed, 50 insertions, 73 deletions
diff --git a/text_recognizer/networks/transformer/decoder.py b/text_recognizer/networks/transformer/decoder.py
index efa1e89..c7da226 100644
--- a/text_recognizer/networks/transformer/decoder.py
+++ b/text_recognizer/networks/transformer/decoder.py
@@ -5,88 +5,19 @@ from typing import Optional, Type
from torch import Tensor, nn
from text_recognizer.networks.transformer.attention import Attention
+from text_recognizer.networks.transformer.decoder_block import DecoderBlock
from text_recognizer.networks.transformer.ff import FeedForward
-class DecoderBlock(nn.Module):
- """Decoder block."""
-
- def __init__(
- self,
- self_attn: Attention,
- norm: Type[nn.Module],
- ff: FeedForward,
- cross_attn: Optional[Attention] = None,
- ) -> None:
- super().__init__()
- self.layers = ("self_attn", "cross_attn", "ff")
- self.has_pos_emb = self_attn.rotary_embedding is not None
- self.blocks = self._build(self_attn, norm, ff, cross_attn)
-
- def _build(
- self,
- self_attn: Attention,
- norm: Type[nn.Module],
- ff: FeedForward,
- cross_attn: Optional[Attention],
- ) -> nn.ModuleDict:
- return nn.ModuleDict(
- {
- self.layers[0]: nn.ModuleList([norm, self_attn]),
- self.layers[1]: nn.ModuleList([deepcopy(norm), cross_attn]),
- self.layers[2]: nn.ModuleList([deepcopy(norm), ff]),
- }
- )
-
- def _apply_block(
- self,
- layer: str,
- x: Tensor,
- context: Optional[Tensor] = None,
- input_mask: Optional[Tensor] = None,
- context_mask: Optional[Tensor] = None,
- ) -> Tensor:
- """Applies block function."""
- residual = x
- norm_fn, layer_fn = self.blocks[layer]
- if layer == "self_attn":
- out = layer_fn(x=x, input_mask=input_mask)
- elif layer == "cross_attn":
- out = layer_fn(
- x=x, context=context, input_mask=input_mask, context_mask=context_mask
- )
- else:
- out = layer_fn(x)
- out += residual
- return norm_fn(out)
-
- def forward(
- self,
- x: Tensor,
- context: Optional[Tensor] = None,
- input_mask: Optional[Tensor] = None,
- context_mask: Optional[Tensor] = None,
- ) -> Tensor:
- """Applies decoder block on input signals."""
- for layer in self.layers:
- x = self._apply_block(
- layer=layer,
- x=x,
- context=context,
- input_mask=input_mask,
- context_mask=context_mask,
- )
- return x
-
-
class Decoder(nn.Module):
"""Decoder Network."""
- def __init__(self, depth: int, block: DecoderBlock) -> None:
+ def __init__(self, depth: int, dim: int, block: DecoderBlock) -> None:
super().__init__()
self.depth = depth
self.has_pos_emb = block.has_pos_emb
self.blocks = nn.ModuleList([deepcopy(block) for _ in range(self.depth)])
+ self.ln = nn.LayerNorm(dim)
def forward(
self,
@@ -100,4 +31,4 @@ class Decoder(nn.Module):
x = block(
x=x, context=context, input_mask=input_mask, context_mask=context_mask
)
- return x
+ return self.ln(x)
diff --git a/text_recognizer/networks/transformer/decoder_block.py b/text_recognizer/networks/transformer/decoder_block.py
new file mode 100644
index 0000000..e6e7fb8
--- /dev/null
+++ b/text_recognizer/networks/transformer/decoder_block.py
@@ -0,0 +1,46 @@
+"""Transformer decoder module."""
+from copy import deepcopy
+from typing import Optional, Type
+
+from torch import Tensor, nn
+
+from text_recognizer.networks.transformer.attention import Attention
+from text_recognizer.networks.transformer.ff import FeedForward
+
+
+class DecoderBlock(nn.Module):
+ """Decoder block."""
+
+ def __init__(
+ self,
+ self_attn: Attention,
+ norm: Type[nn.Module],
+ ff: FeedForward,
+ cross_attn: Optional[Attention] = None,
+ ) -> None:
+ super().__init__()
+ self.ln_attn = norm
+ self.attn = self_attn
+ self.ln_cross_attn = deepcopy(norm)
+ self.cross_attn = cross_attn
+ self.ln_ff = deepcopy(norm)
+ self.ff = ff
+ self.has_pos_emb = self.attn.rotary_embedding is not None
+
+ def forward(
+ self,
+ x: Tensor,
+ context: Optional[Tensor] = None,
+ input_mask: Optional[Tensor] = None,
+ context_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ """Applies decoder block on input signals."""
+ x = x + self.attn(self.ln_attn(x), input_mask=input_mask)
+ x += self.cross_attn(
+ x=self.ln_cross_attn(x),
+ context=context,
+ input_mask=input_mask,
+ context_mask=context_mask,
+ )
+ x += self.ff(self.ln_ff(x))
+ return x