summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/transformer')
-rw-r--r--text_recognizer/networks/transformer/layers.py27
1 files changed, 14 insertions, 13 deletions
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index b132522..ca4569f 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -26,7 +26,7 @@ class AttentionLayers(nn.Module):
norm_fn: str = attr.ib()
ff_fn: str = attr.ib()
ff_kwargs: DictConfig = attr.ib()
- rotary_emb: Optional[RotaryEmbedding] = attr.ib()
+ rotary_emb: Optional[RotaryEmbedding] = attr.ib(default=None)
local_attn_fn: Optional[str] = attr.ib(default=None)
local_attn_kwargs: Optional[DictConfig] = attr.ib(default=None)
causal: bool = attr.ib(default=False)
@@ -35,16 +35,17 @@ class AttentionLayers(nn.Module):
local_depth: Optional[int] = attr.ib(default=None)
layer_types: Tuple[str, ...] = attr.ib(init=False)
layers: nn.ModuleList = attr.ib(init=False)
+ has_pos_emb: bool = attr.ib(init=False, default=False)
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
- self.layer_types = self._get_layer_types() * self.depth
- self.layers = self._build_network()
-
if self.local_attn_kwargs is not None and self.local_attn_fn is not None:
if "depth" not in self.local_attn_kwargs:
ValueError("Local depth has to be specified")
self.local_depth = self.local_attn_kwargs.pop("depth")
+ self.layer_types = self._get_layer_types() * self.depth
+ self.layers = self._build_network()
+ self.has_pos_emb = self.rotary_emb is None
def _get_layer_types(self) -> Tuple:
"""Get layer specification."""
@@ -53,7 +54,7 @@ class AttentionLayers(nn.Module):
return "a", "f"
def _configure_causal_attn(self, i: int) -> Type[nn.Module]:
- if self.local_depth is not None and i <= self.local_depth:
+ if self.local_depth is not None and i < self.local_depth:
return load_partial_fn(
self.local_attn_fn,
dim=self.dim,
@@ -73,9 +74,11 @@ class AttentionLayers(nn.Module):
norm = load_partial_fn(self.norm_fn, normalized_shape=self.dim)
ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs)
layers = nn.ModuleList([])
- for i, layer_type in enumerate(self.layer_types):
+ self_attn_depth = 0
+ for layer_type in self.layer_types:
if layer_type == "a":
- layer = self._configure_causal_attn(i)
+ layer = self._configure_causal_attn(self_attn_depth)
+ self_attn_depth += 1
elif layer_type == "c":
layer = load_partial_fn(
self.attn_fn,
@@ -122,12 +125,10 @@ class AttentionLayers(nn.Module):
return x
-@attr.s(auto_attribs=True, eq=False)
-class Encoder(AttentionLayers):
- causal: bool = attr.ib(default=False, init=False)
-
-
class Decoder(AttentionLayers):
+ """Decoder module."""
+
def __init__(self, **kwargs: Any) -> None:
- assert "causal" not in kwargs, "Cannot set causality on decoder"
+ if "causal" in kwargs:
+ ValueError("Cannot set causality on decoder")
super().__init__(causal=True, **kwargs)