Diffstat (limited to 'text_recognizer')
-rw-r--r--  text_recognizer/networks/transformer/decoder.py  26
1 file changed, 14 insertions(+), 12 deletions(-)
diff --git a/text_recognizer/networks/transformer/decoder.py b/text_recognizer/networks/transformer/decoder.py
index a58f7bd..1812e40 100644
--- a/text_recognizer/networks/transformer/decoder.py
+++ b/text_recognizer/networks/transformer/decoder.py
@@ -1,6 +1,6 @@
"""Transformer decoder module."""
from copy import deepcopy
-from typing import Optional, Tuple, Type
+from typing import Optional, Type
from torch import nn, Tensor
@@ -19,8 +19,8 @@ class DecoderBlock(nn.Module):
cross_attn: Optional[Attention] = None,
) -> None:
super().__init__()
- self._layers = ("self_attn", "cross_attn", "ff")
- self._blocks = self._build(self_attn, norm, ff, cross_attn)
+ self.layers = ("self_attn", "cross_attn", "ff")
+ self.blocks = self._build(self_attn, norm, ff, cross_attn)
def _build(
self,
@@ -37,7 +37,7 @@ class DecoderBlock(nn.Module):
}
)
- def _apply(
+ def _apply_block(
self,
layer: str,
x: Tensor,
@@ -45,8 +45,9 @@ class DecoderBlock(nn.Module):
input_mask: Optional[Tensor] = None,
context_mask: Optional[Tensor] = None,
) -> Tensor:
+ """Applies block function."""
residual = x
- norm_fn, layer_fn = self._blocks[layer]
+ norm_fn, layer_fn = self.blocks[layer]
if layer == "self_attn":
out = layer_fn(x=x, input_mask=input_mask)
elif layer == "cross_attn":
@@ -66,8 +67,8 @@ class DecoderBlock(nn.Module):
context_mask: Optional[Tensor] = None,
) -> Tensor:
"""Applies decoder block on input signals."""
- for layer in self._layers:
- x = self._apply(
+ for layer in self.layers:
+ x = self._apply_block(
layer=layer,
x=x,
context=context,
@@ -77,13 +78,14 @@ class DecoderBlock(nn.Module):
return x
-class Decoder:
+class Decoder(nn.Module):
"""Decoder Network."""
- def __init__(self, depth: int, block: DecoderBlock) -> None:
+ def __init__(self, depth: int, has_pos_emb: bool, block: DecoderBlock) -> None:
+ super().__init__()
self.depth = depth
- self.has_pos_emb: bool = block.rotary_embedding is not None
- self._block = nn.ModuleList([deepcopy(block) for _ in range(self.depth)])
+ self.has_pos_emb = has_pos_emb
+ self.blocks = nn.ModuleList([deepcopy(block) for _ in range(self.depth)])
def forward(
self,
@@ -93,7 +95,7 @@ class Decoder:
context_mask: Optional[Tensor] = None,
) -> Tensor:
"""Applies the network to the signals."""
- for block in self._blocks:
+ for block in self.blocks:
x = block(
x=x, context=context, input_mask=input_mask, context_mask=context_mask
)
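
Below is a minimal, self-contained sketch (not code from this repository) illustrating why Decoder now derives from nn.Module: modules stored in an nn.ModuleList are only registered with their parent, and therefore visible to .parameters(), .to(), and .state_dict(), when the parent itself is an nn.Module. ToyBlock and ToyDecoder are hypothetical stand-ins for DecoderBlock and Decoder; only the constructor shape (depth, has_pos_emb, block) and the deepcopy-into-ModuleList pattern are taken from the diff above.

    """Sketch: registering deepcopied blocks via an nn.Module parent."""
    from copy import deepcopy

    import torch
    from torch import nn, Tensor


    class ToyBlock(nn.Module):
        """Hypothetical stand-in for DecoderBlock; one linear layer is enough here."""

        def __init__(self, dim: int) -> None:
            super().__init__()
            self.ff = nn.Linear(dim, dim)

        def forward(self, x: Tensor) -> Tensor:
            return self.ff(x)


    class ToyDecoder(nn.Module):
        """Mirrors the refactored Decoder: nn.Module subclass, explicit
        has_pos_emb flag, deepcopied blocks held in an nn.ModuleList."""

        def __init__(self, depth: int, has_pos_emb: bool, block: ToyBlock) -> None:
            super().__init__()
            self.depth = depth
            self.has_pos_emb = has_pos_emb
            self.blocks = nn.ModuleList([deepcopy(block) for _ in range(self.depth)])

        def forward(self, x: Tensor) -> Tensor:
            for block in self.blocks:
                x = block(x)
            return x


    decoder = ToyDecoder(depth=2, has_pos_emb=False, block=ToyBlock(dim=8))
    # Parameters are registered because ToyDecoder derives from nn.Module.
    print(sum(p.numel() for p in decoder.parameters()))  # 2 * (8 * 8 + 8) = 144
    out = decoder(torch.randn(1, 4, 8))
    print(out.shape)  # torch.Size([1, 4, 8])

Note also that has_pos_emb is now supplied by the caller rather than inferred from block.rotary_embedding, since the removed line depended on an attribute that the DecoderBlock shown in this diff does not expose.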