summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-01 00:35:41 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-01 00:35:41 +0100
commit5b2c729e819d1e1e5a6752a3952592259ea48f8a (patch)
treec391ca7dce7ad25e0d5a85fd816b8f0ee977e9b1
parent7808b54b5bd146bb3671bee5d4540513826e96ea (diff)
Refactor transformer layer
-rw-r--r--text_recognizer/networks/transformer/layers.py82
1 files changed, 30 insertions, 52 deletions
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index ca4569f..941c141 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -1,13 +1,14 @@
"""Transformer attention layer."""
+from copy import deepcopy
from typing import Any, Optional, Tuple, Type
import attr
-from omegaconf.dictconfig import DictConfig
from torch import nn, Tensor
-from text_recognizer.networks.transformer.embeddings.rotary import RotaryEmbedding
+from text_recognizer.networks.transformer.attention import Attention
+from text_recognizer.networks.transformer.local_attention import LocalAttention
+from text_recognizer.networks.transformer.mlp import FeedForward
from text_recognizer.networks.transformer.residual import Residual
-from text_recognizer.networks.util import load_partial_fn
@attr.s(eq=False)
@@ -18,78 +19,58 @@ class AttentionLayers(nn.Module):
"""Pre init constructor."""
super().__init__()
- dim: int = attr.ib()
depth: int = attr.ib()
- num_heads: int = attr.ib()
- attn_fn: str = attr.ib()
- attn_kwargs: DictConfig = attr.ib()
- norm_fn: str = attr.ib()
- ff_fn: str = attr.ib()
- ff_kwargs: DictConfig = attr.ib()
- rotary_emb: Optional[RotaryEmbedding] = attr.ib(default=None)
- local_attn_fn: Optional[str] = attr.ib(default=None)
- local_attn_kwargs: Optional[DictConfig] = attr.ib(default=None)
- causal: bool = attr.ib(default=False)
- cross_attend: bool = attr.ib(default=False)
+ self_attn: Attention = attr.ib()
+ norm: Type[nn.Module] = attr.ib()
+ ff: FeedForward = attr.ib()
+ cross_attn: Optional[Attention] = attr.ib(default=None)
+ local_self_attn: Optional[LocalAttention] = attr.ib(default=None)
pre_norm: bool = attr.ib(default=True)
local_depth: Optional[int] = attr.ib(default=None)
+ has_pos_emb: bool = attr.ib(default=False)
layer_types: Tuple[str, ...] = attr.ib(init=False)
layers: nn.ModuleList = attr.ib(init=False)
- has_pos_emb: bool = attr.ib(init=False, default=False)
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
- if self.local_attn_kwargs is not None and self.local_attn_fn is not None:
- if "depth" not in self.local_attn_kwargs:
+ if self.local_self_attn is not None:
+ if self.local_depth is None:
ValueError("Local depth has to be specified")
- self.local_depth = self.local_attn_kwargs.pop("depth")
self.layer_types = self._get_layer_types() * self.depth
self.layers = self._build_network()
- self.has_pos_emb = self.rotary_emb is None
def _get_layer_types(self) -> Tuple:
"""Get layer specification."""
- if self.cross_attend:
+ if self.cross_attn is not None:
return "a", "c", "f"
return "a", "f"
- def _configure_causal_attn(self, i: int) -> Type[nn.Module]:
+ def _self_attn_block(self, i: int) -> Type[nn.Module]:
if self.local_depth is not None and i < self.local_depth:
- return load_partial_fn(
- self.local_attn_fn,
- dim=self.dim,
- num_heads=self.num_heads,
- **self.local_attn_kwargs,
- )()
- return load_partial_fn(
- self.attn_fn,
- causal=self.causal,
- dim=self.dim,
- num_heads=self.num_heads,
- **self.attn_kwargs,
- )()
+ return deepcopy(self.local_self_attn)
+ return deepcopy(self.self_attn)
+
+ def _delete(self) -> None:
+ del self.self_attn
+ del self.local_self_attn
+ del self.ff
+ del self.norm
+ del self.cross_attn
def _build_network(self) -> nn.ModuleList:
"""Configures transformer network."""
- norm = load_partial_fn(self.norm_fn, normalized_shape=self.dim)
- ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs)
layers = nn.ModuleList([])
self_attn_depth = 0
for layer_type in self.layer_types:
if layer_type == "a":
- layer = self._configure_causal_attn(self_attn_depth)
+ layer = self._self_attn_block(self_attn_depth)
self_attn_depth += 1
elif layer_type == "c":
- layer = load_partial_fn(
- self.attn_fn,
- dim=self.dim,
- num_heads=self.num_heads,
- **self.attn_kwargs,
- )()
+ layer = deepcopy(self.cross_attn)
elif layer_type == "f":
- layer = ff()
- residual_fn = Residual()
- layers.append(nn.ModuleList([norm(), layer, residual_fn]))
+ layer = deepcopy(self.ff)
+ layers.append(nn.ModuleList([deepcopy(self.norm), layer, Residual()]))
+ self._delete()
return layers
def forward(
@@ -100,7 +81,6 @@ class AttentionLayers(nn.Module):
context_mask: Optional[Tensor] = None,
) -> Tensor:
"""Forward pass."""
- rotary_pos_emb = self.rotary_emb(x) if self.rotary_emb is not None else None
for i, (layer_type, (norm, block, residual_fn)) in enumerate(
zip(self.layer_types, self.layers)
):
@@ -111,7 +91,7 @@ class AttentionLayers(nn.Module):
x = norm(x)
if layer_type == "a":
- out, _ = block(x=x, mask=mask, rotary_pos_emb=rotary_pos_emb)
+ out, _ = block(x=x, mask=mask)
elif layer_type == "c":
out, _ = block(x, context=context, mask=mask, context_mask=context_mask)
elif layer_type == "f":
@@ -129,6 +109,4 @@ class Decoder(AttentionLayers):
"""Decoder module."""
def __init__(self, **kwargs: Any) -> None:
- if "causal" in kwargs:
- ValueError("Cannot set causality on decoder")
- super().__init__(causal=True, **kwargs)
+ super().__init__(**kwargs)