From 548f52b35062e258622ea638ed1b132d6759a07a Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 9 May 2021 00:36:55 +0200 Subject: Attention layer soon done --- notebooks/00-scratch-pad.ipynb | 436 ++++++++++++++++++++- .../networks/transformer/attention_layers.py | 19 - text_recognizer/networks/transformer/layers.py | 77 ++++ text_recognizer/networks/transformer/norm.py | 11 - .../networks/transformer/nystromer/__init__.py | 2 + text_recognizer/networks/transformer/residual.py | 8 + 6 files changed, 517 insertions(+), 36 deletions(-) delete mode 100644 text_recognizer/networks/transformer/attention_layers.py create mode 100644 text_recognizer/networks/transformer/layers.py create mode 100644 text_recognizer/networks/transformer/residual.py diff --git a/notebooks/00-scratch-pad.ipynb b/notebooks/00-scratch-pad.ipynb index d50fd59..b6ec2c8 100644 --- a/notebooks/00-scratch-pad.ipynb +++ b/notebooks/00-scratch-pad.ipynb @@ -23,6 +23,410 @@ " sys.path.append('..')" ] }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from text_recognizer.networks.transformer.nystromer.nystromer import Nystromer" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "from functools import partial" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "efficient_transformer = partial(Nystromer,\n", + " dim = 512,\n", + " depth = 12,\n", + " num_heads = 8,\n", + " num_landmarks = 256\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "functools.partial" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "type(efficient_transformer)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "efficient_transformer = 
efficient_transformer(num_landmarks=256)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Nystromer(\n", + " (layers): ModuleList(\n", + " (0): ModuleList(\n", + " (0): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): NystromAttention(\n", + " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n", + " (fc_out): Sequential(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n", + " )\n", + " )\n", + " (1): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): FeedForward(\n", + " (mlp): Sequential(\n", + " (0): GEGLU(\n", + " (fc): Linear(in_features=512, out_features=4096, bias=True)\n", + " )\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): Linear(in_features=2048, out_features=512, bias=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (1): ModuleList(\n", + " (0): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): NystromAttention(\n", + " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n", + " (fc_out): Sequential(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n", + " )\n", + " )\n", + " (1): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): FeedForward(\n", + " (mlp): Sequential(\n", + " (0): GEGLU(\n", + " (fc): Linear(in_features=512, out_features=4096, bias=True)\n", + " )\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): Linear(in_features=2048, out_features=512, bias=True)\n", + " )\n", + " )\n", + " )\n", + " 
)\n", + " (2): ModuleList(\n", + " (0): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): NystromAttention(\n", + " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n", + " (fc_out): Sequential(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n", + " )\n", + " )\n", + " (1): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): FeedForward(\n", + " (mlp): Sequential(\n", + " (0): GEGLU(\n", + " (fc): Linear(in_features=512, out_features=4096, bias=True)\n", + " )\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): Linear(in_features=2048, out_features=512, bias=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (3): ModuleList(\n", + " (0): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): NystromAttention(\n", + " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n", + " (fc_out): Sequential(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n", + " )\n", + " )\n", + " (1): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): FeedForward(\n", + " (mlp): Sequential(\n", + " (0): GEGLU(\n", + " (fc): Linear(in_features=512, out_features=4096, bias=True)\n", + " )\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): Linear(in_features=2048, out_features=512, bias=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (4): ModuleList(\n", + " (0): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): NystromAttention(\n", + " (qkv_fn): Linear(in_features=512, out_features=1536, 
bias=False)\n", + " (fc_out): Sequential(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n", + " )\n", + " )\n", + " (1): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): FeedForward(\n", + " (mlp): Sequential(\n", + " (0): GEGLU(\n", + " (fc): Linear(in_features=512, out_features=4096, bias=True)\n", + " )\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): Linear(in_features=2048, out_features=512, bias=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (5): ModuleList(\n", + " (0): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): NystromAttention(\n", + " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n", + " (fc_out): Sequential(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n", + " )\n", + " )\n", + " (1): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): FeedForward(\n", + " (mlp): Sequential(\n", + " (0): GEGLU(\n", + " (fc): Linear(in_features=512, out_features=4096, bias=True)\n", + " )\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): Linear(in_features=2048, out_features=512, bias=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (6): ModuleList(\n", + " (0): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): NystromAttention(\n", + " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n", + " (fc_out): Sequential(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (residual): Conv2d(8, 8, kernel_size=(33, 
1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n", + " )\n", + " )\n", + " (1): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): FeedForward(\n", + " (mlp): Sequential(\n", + " (0): GEGLU(\n", + " (fc): Linear(in_features=512, out_features=4096, bias=True)\n", + " )\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): Linear(in_features=2048, out_features=512, bias=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (7): ModuleList(\n", + " (0): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): NystromAttention(\n", + " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n", + " (fc_out): Sequential(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n", + " )\n", + " )\n", + " (1): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): FeedForward(\n", + " (mlp): Sequential(\n", + " (0): GEGLU(\n", + " (fc): Linear(in_features=512, out_features=4096, bias=True)\n", + " )\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): Linear(in_features=2048, out_features=512, bias=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (8): ModuleList(\n", + " (0): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): NystromAttention(\n", + " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n", + " (fc_out): Sequential(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n", + " )\n", + " )\n", + " (1): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): FeedForward(\n", + " (mlp): 
Sequential(\n", + " (0): GEGLU(\n", + " (fc): Linear(in_features=512, out_features=4096, bias=True)\n", + " )\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): Linear(in_features=2048, out_features=512, bias=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (9): ModuleList(\n", + " (0): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): NystromAttention(\n", + " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n", + " (fc_out): Sequential(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n", + " )\n", + " )\n", + " (1): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): FeedForward(\n", + " (mlp): Sequential(\n", + " (0): GEGLU(\n", + " (fc): Linear(in_features=512, out_features=4096, bias=True)\n", + " )\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): Linear(in_features=2048, out_features=512, bias=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (10): ModuleList(\n", + " (0): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): NystromAttention(\n", + " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n", + " (fc_out): Sequential(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n", + " )\n", + " )\n", + " (1): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): FeedForward(\n", + " (mlp): Sequential(\n", + " (0): GEGLU(\n", + " (fc): Linear(in_features=512, out_features=4096, bias=True)\n", + " )\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): Linear(in_features=2048, out_features=512, 
bias=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (11): ModuleList(\n", + " (0): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): NystromAttention(\n", + " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n", + " (fc_out): Sequential(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n", + " )\n", + " )\n", + " (1): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): FeedForward(\n", + " (mlp): Sequential(\n", + " (0): GEGLU(\n", + " (fc): Linear(in_features=512, out_features=4096, bias=True)\n", + " )\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): Linear(in_features=2048, out_features=512, bias=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " )\n", + ")" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "efficient_transformer()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "8" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(list(filter(lambda x: x == \"a\", (\"a\", \"c\") * 8)))" + ] + }, { "cell_type": "code", "execution_count": 2, @@ -486,7 +890,7 @@ { "data": { "text/plain": [ - "144" + "18" ] }, "execution_count": 30, @@ -495,27 +899,47 @@ } ], "source": [ - "576 // 4" + "576 // 32" ] }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "160" + "20" ] }, - "execution_count": 29, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "640 // 32" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + 
"outputs": [ + { + "data": { + "text/plain": [ + "360" + ] + }, + "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "640 // 4" + "18 * 20" ] }, { diff --git a/text_recognizer/networks/transformer/attention_layers.py b/text_recognizer/networks/transformer/attention_layers.py deleted file mode 100644 index 721fa27..0000000 --- a/text_recognizer/networks/transformer/attention_layers.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Generates the attention layer architecture.""" -from typing import Type - -import torch -from torch import nn, Tensor - - -class AttentionLayers(nn.Module): - def __init__( - self, - dim: int, - depth: int, - num_heads: int, - norm_layer: Type[nn.Module], - causal: bool = False, - cross_attend: bool = False, - only_cross: bool = False, - ) -> None: - pass diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py new file mode 100644 index 0000000..1c951ae --- /dev/null +++ b/text_recognizer/networks/transformer/layers.py @@ -0,0 +1,77 @@ +"""Generates the attention layer architecture.""" +from functools import partial +from typing import Dict, Optional, Type + +from click.types import Tuple + +import torch +from torch import nn, Tensor + +from .attention import Attention +from .mlp import FeedForward +from .residual import Residual + + +class AttentionLayers(nn.Module): + def __init__( + self, + dim: int, + depth: int, + num_heads: int, + ff_kwargs: Dict, + attn_kwargs: Dict, + attn_fn: Type[nn.Module] = Attention, + norm_fn: Type[nn.Module] = nn.LayerNorm, + ff_fn: Type[nn.Module] = FeedForward, + residual_fn: Type[nn.Module] = Residual, + causal: bool = False, + cross_attend: bool = False, + ) -> None: + super().__init__() + attn_fn = partial(attn_fn, dim=dim, num_heads=num_heads, **attn_kwargs) + norm_fn = partial(norm_fn, dim=dim) + ff_fn = partial(ff_fn, dim=dim, **ff_kwargs) + layer_types = self._get_layer_types(cross_attend) * depth + self.layers = 
self._build_network( + layer_types, causal, attn_fn, norm_fn, ff_fn, residual_fn + ) + + @staticmethod + def _get_layer_types(cross_attend: bool) -> Tuple: + """Get layer specification.""" + if cross_attend: + return "a", "c", "f" + return "a", "f" + + @staticmethod + def _build_network( + layer_types: Tuple, + causal: bool, + attn_fn: partial, + norm_fn: partial, + ff_fn: partial, + residual_fn: Type[nn.Module], + ) -> nn.ModuleList: + """Configures transformer layers.""" + layers = nn.ModuleList([]) + for layer_type in layer_types: + if layer_type == "a": + layer = attn_fn(causal=causal) + elif layer_type == "c": + layer = attn_fn() + elif layer_type == "f": + layer = ff_fn() + + residual_fn = residual_fn() + + layers.append(nn.ModuleList([norm_fn(), layer, residual_fn])) + return layers + + def forward( + self, + x: Tensor, + context: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + context_mask: Optional[Tensor] = None, + ) -> Tensor: + pass diff --git a/text_recognizer/networks/transformer/norm.py b/text_recognizer/networks/transformer/norm.py index 58c8770..8bc3221 100644 --- a/text_recognizer/networks/transformer/norm.py +++ b/text_recognizer/networks/transformer/norm.py @@ -11,17 +11,6 @@ from torch import nn from torch import Tensor -class Rezero(nn.Module): - def __init__(self, fn: Callable) -> None: - super().__init__() - self.fn = fn - self.g = nn.Parameter(torch.zeros(1)) - - def forward(self, x: Tensor, **kwargs: Dict) -> Tensor: - x, *rest = self.fn(x, **kwargs) - return (x * self.g, *rest) - - class ScaleNorm(nn.Module): def __init__(self, dim: int, eps: float = 1.0e-5) -> None: super().__init__() diff --git a/text_recognizer/networks/transformer/nystromer/__init__.py b/text_recognizer/networks/transformer/nystromer/__init__.py index e69de29..ea2c6fc 100644 --- a/text_recognizer/networks/transformer/nystromer/__init__.py +++ b/text_recognizer/networks/transformer/nystromer/__init__.py @@ -0,0 +1,2 @@ +"""Nyströmer module.""" +from 
class Residual(nn.Module):
    """Module that adds a skip-connection tensor to its input."""

    def forward(self, x: Tensor, residual: Tensor) -> Tensor:
        """Applies the residual function.

        Args:
            x: Output of the wrapped layer.
            residual: Tensor to add back onto ``x``.

        Returns:
            The result of ``x + residual``.
        """
        summed = x + residual
        return summed