diff options
| author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-05-09 00:36:55 +0200 | 
|---|---|---|
| committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-05-09 00:36:55 +0200 | 
| commit | 548f52b35062e258622ea638ed1b132d6759a07a (patch) | |
| tree | e9262d0f934ac4f9392f20cb4fcf7be6033e3cb7 /text_recognizer/networks/transformer | |
| parent | 805d5726c17b83e00dcea0d2608dcd83a91f723d (diff) | |
Attention layer soon done
Diffstat (limited to 'text_recognizer/networks/transformer')
5 files changed, 87 insertions, 30 deletions
| diff --git a/text_recognizer/networks/transformer/attention_layers.py b/text_recognizer/networks/transformer/attention_layers.py deleted file mode 100644 index 721fa27..0000000 --- a/text_recognizer/networks/transformer/attention_layers.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Generates the attention layer architecture.""" -from typing import Type - -import torch -from torch import nn, Tensor - - -class AttentionLayers(nn.Module): -    def __init__( -        self, -        dim: int, -        depth: int, -        num_heads: int, -        norm_layer: Type[nn.Module], -        causal: bool = False, -        cross_attend: bool = False, -        only_cross: bool = False, -    ) -> None: -        pass diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py new file mode 100644 index 0000000..1c951ae --- /dev/null +++ b/text_recognizer/networks/transformer/layers.py @@ -0,0 +1,77 @@ +"""Generates the attention layer architecture.""" +from functools import partial +from typing import Dict, Optional, Type + +from click.types import Tuple + +import torch +from torch import nn, Tensor + +from .attention import Attention +from .mlp import FeedForward +from .residual import Residual + + +class AttentionLayers(nn.Module): +    def __init__( +        self, +        dim: int, +        depth: int, +        num_heads: int, +        ff_kwargs: Dict, +        attn_kwargs: Dict, +        attn_fn: Type[nn.Module] = Attention, +        norm_fn: Type[nn.Module] = nn.LayerNorm, +        ff_fn: Type[nn.Module] = FeedForward, +        residual_fn: Type[nn.Module] = Residual, +        causal: bool = False, +        cross_attend: bool = False, +    ) -> None: +        super().__init__() +        attn_fn = partial(attn_fn, dim=dim, num_heads=num_heads, **attn_kwargs) +        norm_fn = partial(norm_fn, dim=dim) +        ff_fn = partial(ff_fn, dim=dim, **ff_kwargs) +        layer_types = self._get_layer_types(cross_attend) * depth +        self.layers = self._build_network( +            layer_types, causal, attn_fn, norm_fn, ff_fn, residual_fn +        ) + +    @staticmethod +    def _get_layer_types(cross_attend: bool) -> Tuple: +        """Get layer specification.""" +        if cross_attend: +            return "a", "c", "f" +        return "a", "f" + +    @staticmethod +    def _build_network( +        layer_types: Tuple, +        causal: bool, +        attn_fn: partial, +        norm_fn: partial, +        ff_fn: partial, +        residual_fn: Type[nn.Module], +    ) -> nn.ModuleList: +        """Configures transformer layers.""" +        layers = nn.ModuleList([]) +        for layer_type in layer_types: +            if layer_type == "a": +                layer = attn_fn(causal=causal) +            elif layer_type == "c": +                layer = attn_fn() +            elif layer_type == "f": +                layer = ff_fn() + +            residual_fn = residual_fn() + +            layers.append(nn.ModuleList([norm_fn(), layer, residual_fn])) +        return layers + +    def forward( +        self, +        x: Tensor, +        context: Optional[Tensor] = None, +        mask: Optional[Tensor] = None, +        context_mask: Optional[Tensor] = None, +    ) -> Tensor: +        pass diff --git a/text_recognizer/networks/transformer/norm.py b/text_recognizer/networks/transformer/norm.py index 58c8770..8bc3221 100644 --- a/text_recognizer/networks/transformer/norm.py +++ b/text_recognizer/networks/transformer/norm.py @@ -11,17 +11,6 @@ from torch import nn  from torch import Tensor -class Rezero(nn.Module): -    def __init__(self, fn: Callable) -> None: -        super().__init__() -        self.fn = fn -        self.g = nn.Parameter(torch.zeros(1)) - -    def forward(self, x: Tensor, **kwargs: Dict) -> Tensor: -        x, *rest = self.fn(x, **kwargs) -        return (x * self.g, *rest) - -  class ScaleNorm(nn.Module):      def __init__(self, dim: int, eps: float = 1.0e-5) -> None:          super().__init__() diff --git a/text_recognizer/networks/transformer/nystromer/__init__.py b/text_recognizer/networks/transformer/nystromer/__init__.py index e69de29..ea2c6fc 100644 --- a/text_recognizer/networks/transformer/nystromer/__init__.py +++ b/text_recognizer/networks/transformer/nystromer/__init__.py @@ -0,0 +1,2 @@ +"""Nyströmer module.""" +from .nystromer import Nystromer diff --git a/text_recognizer/networks/transformer/residual.py b/text_recognizer/networks/transformer/residual.py new file mode 100644 index 0000000..1547df6 --- /dev/null +++ b/text_recognizer/networks/transformer/residual.py @@ -0,0 +1,8 @@ +"""Residual function.""" +from torch import nn, Tensor + + +class Residual(nn.Module): +    def forward(self, x: Tensor, residual: Tensor) -> Tensor: +        """Applies the residual function.""" +        return x + residual |