diff options
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/transformer/layers.py | 82 |
1 files changed, 30 insertions, 52 deletions
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py index ca4569f..941c141 100644 --- a/text_recognizer/networks/transformer/layers.py +++ b/text_recognizer/networks/transformer/layers.py @@ -1,13 +1,14 @@ """Transformer attention layer.""" +from copy import deepcopy from typing import Any, Optional, Tuple, Type import attr -from omegaconf.dictconfig import DictConfig from torch import nn, Tensor -from text_recognizer.networks.transformer.embeddings.rotary import RotaryEmbedding +from text_recognizer.networks.transformer.attention import Attention +from text_recognizer.networks.transformer.local_attention import LocalAttention +from text_recognizer.networks.transformer.mlp import FeedForward from text_recognizer.networks.transformer.residual import Residual -from text_recognizer.networks.util import load_partial_fn @attr.s(eq=False) @@ -18,78 +19,58 @@ class AttentionLayers(nn.Module): """Pre init constructor.""" super().__init__() - dim: int = attr.ib() depth: int = attr.ib() - num_heads: int = attr.ib() - attn_fn: str = attr.ib() - attn_kwargs: DictConfig = attr.ib() - norm_fn: str = attr.ib() - ff_fn: str = attr.ib() - ff_kwargs: DictConfig = attr.ib() - rotary_emb: Optional[RotaryEmbedding] = attr.ib(default=None) - local_attn_fn: Optional[str] = attr.ib(default=None) - local_attn_kwargs: Optional[DictConfig] = attr.ib(default=None) - causal: bool = attr.ib(default=False) - cross_attend: bool = attr.ib(default=False) + self_attn: Attention = attr.ib() + norm: Type[nn.Module] = attr.ib() + ff: FeedForward = attr.ib() + cross_attn: Optional[Attention] = attr.ib(default=None) + local_self_attn: Optional[LocalAttention] = attr.ib(default=None) pre_norm: bool = attr.ib(default=True) local_depth: Optional[int] = attr.ib(default=None) + has_pos_emb: bool = attr.ib(default=False) layer_types: Tuple[str, ...] = attr.ib(init=False) layers: nn.ModuleList = attr.ib(init=False) - has_pos_emb: bool = attr.ib(init=False, default=False) def __attrs_post_init__(self) -> None: """Post init configuration.""" - if self.local_attn_kwargs is not None and self.local_attn_fn is not None: - if "depth" not in self.local_attn_kwargs: + if self.local_self_attn is not None: + if self.local_depth is None: ValueError("Local depth has to be specified") - self.local_depth = self.local_attn_kwargs.pop("depth") self.layer_types = self._get_layer_types() * self.depth self.layers = self._build_network() - self.has_pos_emb = self.rotary_emb is None def _get_layer_types(self) -> Tuple: """Get layer specification.""" - if self.cross_attend: + if self.cross_attn is not None: return "a", "c", "f" return "a", "f" - def _configure_causal_attn(self, i: int) -> Type[nn.Module]: + def _self_attn_block(self, i: int) -> Type[nn.Module]: if self.local_depth is not None and i < self.local_depth: - return load_partial_fn( - self.local_attn_fn, - dim=self.dim, - num_heads=self.num_heads, - **self.local_attn_kwargs, - )() - return load_partial_fn( - self.attn_fn, - causal=self.causal, - dim=self.dim, - num_heads=self.num_heads, - **self.attn_kwargs, - )() + return deepcopy(self.local_self_attn) + return deepcopy(self.self_attn) + + def _delete(self) -> None: + del self.self_attn + del self.local_self_attn + del self.ff + del self.norm + del self.cross_attn def _build_network(self) -> nn.ModuleList: """Configures transformer network.""" - norm = load_partial_fn(self.norm_fn, normalized_shape=self.dim) - ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs) layers = nn.ModuleList([]) self_attn_depth = 0 for layer_type in self.layer_types: if layer_type == "a": - layer = self._configure_causal_attn(self_attn_depth) + layer = self._self_attn_block(self_attn_depth) self_attn_depth += 1 elif layer_type == "c": - layer = load_partial_fn( - self.attn_fn, - dim=self.dim, - num_heads=self.num_heads, - **self.attn_kwargs, - )() + layer = deepcopy(self.cross_attn) elif layer_type == "f": - layer = ff() - residual_fn = Residual() - layers.append(nn.ModuleList([norm(), layer, residual_fn])) + layer = deepcopy(self.ff) + layers.append(nn.ModuleList([deepcopy(self.norm), layer, Residual()])) + self._delete() return layers def forward( @@ -100,7 +81,6 @@ class AttentionLayers(nn.Module): context_mask: Optional[Tensor] = None, ) -> Tensor: """Forward pass.""" - rotary_pos_emb = self.rotary_emb(x) if self.rotary_emb is not None else None for i, (layer_type, (norm, block, residual_fn)) in enumerate( zip(self.layer_types, self.layers) ): @@ -111,7 +91,7 @@ class AttentionLayers(nn.Module): x = norm(x) if layer_type == "a": - out, _ = block(x=x, mask=mask, rotary_pos_emb=rotary_pos_emb) + out, _ = block(x=x, mask=mask) elif layer_type == "c": out, _ = block(x, context=context, mask=mask, context_mask=context_mask) elif layer_type == "f": @@ -129,6 +109,4 @@ class Decoder(AttentionLayers): """Decoder module.""" def __init__(self, **kwargs: Any) -> None: - if "causal" in kwargs: - ValueError("Cannot set causality on decoder") - super().__init__(causal=True, **kwargs) + super().__init__(**kwargs) |