diff options
Diffstat (limited to 'text_recognizer/networks/transformer/layers.py')
-rw-r--r-- | text_recognizer/networks/transformer/layers.py | 47 |
1 files changed, 38 insertions, 9 deletions
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py index 2b8427d..b132522 100644 --- a/text_recognizer/networks/transformer/layers.py +++ b/text_recognizer/networks/transformer/layers.py @@ -1,7 +1,8 @@ """Transformer attention layer.""" -from typing import Any, Dict, Optional, Tuple +from typing import Any, Optional, Tuple, Type import attr +from omegaconf.dictconfig import DictConfig from torch import nn, Tensor from text_recognizer.networks.transformer.embeddings.rotary import RotaryEmbedding @@ -14,20 +15,24 @@ class AttentionLayers(nn.Module): """Standard transfomer layer.""" def __attrs_pre_init__(self) -> None: + """Pre init constructor.""" super().__init__() dim: int = attr.ib() depth: int = attr.ib() num_heads: int = attr.ib() attn_fn: str = attr.ib() - attn_kwargs: Dict = attr.ib() + attn_kwargs: DictConfig = attr.ib() norm_fn: str = attr.ib() ff_fn: str = attr.ib() - ff_kwargs: Dict = attr.ib() + ff_kwargs: DictConfig = attr.ib() rotary_emb: Optional[RotaryEmbedding] = attr.ib() + local_attn_fn: Optional[str] = attr.ib(default=None) + local_attn_kwargs: Optional[DictConfig] = attr.ib(default=None) causal: bool = attr.ib(default=False) cross_attend: bool = attr.ib(default=False) pre_norm: bool = attr.ib(default=True) + local_depth: Optional[int] = attr.ib(default=None) layer_types: Tuple[str, ...] = attr.ib(init=False) layers: nn.ModuleList = attr.ib(init=False) @@ -36,25 +41,48 @@ class AttentionLayers(nn.Module): self.layer_types = self._get_layer_types() * self.depth self.layers = self._build_network() + if self.local_attn_kwargs is not None and self.local_attn_fn is not None: + if "depth" not in self.local_attn_kwargs: + ValueError("Local depth has to be specified") + self.local_depth = self.local_attn_kwargs.pop("depth") + def _get_layer_types(self) -> Tuple: """Get layer specification.""" if self.cross_attend: return "a", "c", "f" return "a", "f" + def _configure_causal_attn(self, i: int) -> Type[nn.Module]: + if self.local_depth is not None and i <= self.local_depth: + return load_partial_fn( + self.local_attn_fn, + dim=self.dim, + num_heads=self.num_heads, + **self.local_attn_kwargs, + )() + return load_partial_fn( + self.attn_fn, + causal=self.causal, + dim=self.dim, + num_heads=self.num_heads, + **self.attn_kwargs, + )() + def _build_network(self) -> nn.ModuleList: """Configures transformer network.""" - attn = load_partial_fn( - self.attn_fn, dim=self.dim, num_heads=self.num_heads, **self.attn_kwargs - ) norm = load_partial_fn(self.norm_fn, normalized_shape=self.dim) ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs) layers = nn.ModuleList([]) - for layer_type in self.layer_types: + for i, layer_type in enumerate(self.layer_types): if layer_type == "a": - layer = attn(causal=self.causal) + layer = self._configure_causal_attn(i) elif layer_type == "c": - layer = attn() + layer = load_partial_fn( + self.attn_fn, + dim=self.dim, + num_heads=self.num_heads, + **self.attn_kwargs, + )() elif layer_type == "f": layer = ff() residual_fn = Residual() @@ -68,6 +96,7 @@ class AttentionLayers(nn.Module): mask: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, ) -> Tensor: + """Forward pass.""" rotary_pos_emb = self.rotary_emb(x) if self.rotary_emb is not None else None for i, (layer_type, (norm, block, residual_fn)) in enumerate( zip(self.layer_types, self.layers) |