summaryrefslogtreecommitdiff
path: root/text_recognizer/networks
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-27 22:15:15 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-27 22:15:15 +0200
commit617c63a9ac1b10b2f093d115eef5d63d87a03658 (patch)
treebe0c580594ca312e55376f2ff8dd0b35edcda445 /text_recognizer/networks
parent43cf6e431b28b60b62d5689e42a591937d122154 (diff)
Add local attn in transformer layer
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--text_recognizer/networks/transformer/layers.py47
1 files changed, 38 insertions, 9 deletions
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index 2b8427d..b132522 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -1,7 +1,8 @@
"""Transformer attention layer."""
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Optional, Tuple, Type
import attr
+from omegaconf.dictconfig import DictConfig
from torch import nn, Tensor
from text_recognizer.networks.transformer.embeddings.rotary import RotaryEmbedding
@@ -14,20 +15,24 @@ class AttentionLayers(nn.Module):
"""Standard transfomer layer."""
def __attrs_pre_init__(self) -> None:
+ """Pre init constructor."""
super().__init__()
dim: int = attr.ib()
depth: int = attr.ib()
num_heads: int = attr.ib()
attn_fn: str = attr.ib()
- attn_kwargs: Dict = attr.ib()
+ attn_kwargs: DictConfig = attr.ib()
norm_fn: str = attr.ib()
ff_fn: str = attr.ib()
- ff_kwargs: Dict = attr.ib()
+ ff_kwargs: DictConfig = attr.ib()
rotary_emb: Optional[RotaryEmbedding] = attr.ib()
+ local_attn_fn: Optional[str] = attr.ib(default=None)
+ local_attn_kwargs: Optional[DictConfig] = attr.ib(default=None)
causal: bool = attr.ib(default=False)
cross_attend: bool = attr.ib(default=False)
pre_norm: bool = attr.ib(default=True)
+ local_depth: Optional[int] = attr.ib(default=None)
layer_types: Tuple[str, ...] = attr.ib(init=False)
layers: nn.ModuleList = attr.ib(init=False)
@@ -36,25 +41,48 @@ class AttentionLayers(nn.Module):
self.layer_types = self._get_layer_types() * self.depth
self.layers = self._build_network()
+ if self.local_attn_kwargs is not None and self.local_attn_fn is not None:
+ if "depth" not in self.local_attn_kwargs:
+ ValueError("Local depth has to be specified")
+ self.local_depth = self.local_attn_kwargs.pop("depth")
+
def _get_layer_types(self) -> Tuple:
"""Get layer specification."""
if self.cross_attend:
return "a", "c", "f"
return "a", "f"
+ def _configure_causal_attn(self, i: int) -> Type[nn.Module]:
+ if self.local_depth is not None and i <= self.local_depth:
+ return load_partial_fn(
+ self.local_attn_fn,
+ dim=self.dim,
+ num_heads=self.num_heads,
+ **self.local_attn_kwargs,
+ )()
+ return load_partial_fn(
+ self.attn_fn,
+ causal=self.causal,
+ dim=self.dim,
+ num_heads=self.num_heads,
+ **self.attn_kwargs,
+ )()
+
def _build_network(self) -> nn.ModuleList:
"""Configures transformer network."""
- attn = load_partial_fn(
- self.attn_fn, dim=self.dim, num_heads=self.num_heads, **self.attn_kwargs
- )
norm = load_partial_fn(self.norm_fn, normalized_shape=self.dim)
ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs)
layers = nn.ModuleList([])
- for layer_type in self.layer_types:
+ for i, layer_type in enumerate(self.layer_types):
if layer_type == "a":
- layer = attn(causal=self.causal)
+ layer = self._configure_causal_attn(i)
elif layer_type == "c":
- layer = attn()
+ layer = load_partial_fn(
+ self.attn_fn,
+ dim=self.dim,
+ num_heads=self.num_heads,
+ **self.attn_kwargs,
+ )()
elif layer_type == "f":
layer = ff()
residual_fn = Residual()
@@ -68,6 +96,7 @@ class AttentionLayers(nn.Module):
mask: Optional[Tensor] = None,
context_mask: Optional[Tensor] = None,
) -> Tensor:
+ """Forward pass."""
rotary_pos_emb = self.rotary_emb(x) if self.rotary_emb is not None else None
for i, (layer_type, (norm, block, residual_fn)) in enumerate(
zip(self.layer_types, self.layers)