Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--   text_recognizer/networks/cnn_tranformer.py                      |  30
-rw-r--r--   text_recognizer/networks/encoders/efficientnet/efficientnet.py  |  58
-rw-r--r--   text_recognizer/networks/encoders/efficientnet/mbconv.py        |   9
-rw-r--r--   text_recognizer/networks/transformer/__init__.py                |   2
-rw-r--r--   text_recognizer/networks/transformer/attention.py               |  40
-rw-r--r--   text_recognizer/networks/transformer/layers.py                  |  91
-rw-r--r--   text_recognizer/networks/util.py                                |  10
7 files changed, 135 insertions, 105 deletions
diff --git a/text_recognizer/networks/cnn_tranformer.py b/text_recognizer/networks/cnn_tranformer.py
index e030cb8..ce7ec43 100644
--- a/text_recognizer/networks/cnn_tranformer.py
+++ b/text_recognizer/networks/cnn_tranformer.py
@@ -7,6 +7,7 @@ import torch
 from torch import nn, Tensor
 
 from text_recognizer.data.mappings import AbstractMapping
+from text_recognizer.networks.encoders.efficientnet import EfficientNet
 from text_recognizer.networks.transformer.layers import Decoder
 from text_recognizer.networks.transformer.positional_encodings import (
     PositionalEncoding,
@@ -15,7 +16,7 @@ from text_recognizer.networks.transformer.positional_encodings import (
 
 
 @attr.s
-class CnnTransformer(nn.Module):
+class Reader(nn.Module):
     def __attrs_pre_init__(self) -> None:
         super().__init__()
 
@@ -27,21 +28,20 @@ class CnnTransformer(nn.Module):
     num_classes: int = attr.ib()
     padding_idx: int = attr.ib()
     start_token: str = attr.ib()
-    start_index: int = attr.ib(init=False, default=None)
+    start_index: int = attr.ib(init=False)
     end_token: str = attr.ib()
-    end_index: int = attr.ib(init=False, default=None)
+    end_index: int = attr.ib(init=False)
     pad_token: str = attr.ib()
-    pad_index: int = attr.ib(init=False, default=None)
+    pad_index: int = attr.ib(init=False)
 
     # Modules.
-    encoder: Type[nn.Module] = attr.ib()
+    encoder: EfficientNet = attr.ib()
     decoder: Decoder = attr.ib()
-    embedding: nn.Embedding = attr.ib(init=False, default=None)
-    latent_encoder: nn.Sequential = attr.ib(init=False, default=None)
-    token_embedding: nn.Embedding = attr.ib(init=False, default=None)
-    token_pos_encoder: PositionalEncoding = attr.ib(init=False, default=None)
-    head: nn.Linear = attr.ib(init=False, default=None)
-    mapping: AbstractMapping = attr.ib(init=False, default=None)
+    latent_encoder: nn.Sequential = attr.ib(init=False)
+    token_embedding: nn.Embedding = attr.ib(init=False)
+    token_pos_encoder: PositionalEncoding = attr.ib(init=False)
+    head: nn.Linear = attr.ib(init=False)
+    mapping: Type[AbstractMapping] = attr.ib(init=False)
 
     def __attrs_post_init__(self) -> None:
         """Post init configuration."""
@@ -187,12 +187,16 @@ class CnnTransformer(nn.Module):
             output[:, i : i + 1] = tokens[-1:]
 
             # Early stopping of prediction loop if token is end or padding token.
-            if (output[:, i - 1] == self.end_index | output[: i - 1] == self.pad_index).all():
+            if (
+                output[:, i - 1] == self.end_index | output[: i - 1] == self.pad_index
+            ).all():
                 break
 
         # Set all tokens after end token to pad token.
         for i in range(1, self.max_output_len):
-            idx = (output[:, i -1] == self.end_index | output[:, i - 1] == self.pad_index)
+            idx = (
+                output[:, i - 1] == self.end_index | output[:, i - 1] == self.pad_index
+            )
             output[idx, i] = self.pad_index
         return output
diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
index 6719efb..a36150a 100644
--- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
@@ -1,4 +1,7 @@
 """Efficient net."""
+from typing import Tuple
+
+import attr
 from torch import nn, Tensor
 
 from .mbconv import MBConvBlock
@@ -9,10 +12,13 @@ from .utils import (
 )
 
 
+@attr.s
 class EfficientNet(nn.Module):
-    # TODO: attr
+    def __attrs_pre_init__(self) -> None:
+        super().__init__()
+
     archs = {
-        #     width,depth0res,dropout
+        # width, depth, dropout
         "b0": (1.0, 1.0, 0.2),
         "b1": (1.0, 1.1, 0.2),
         "b2": (1.1, 1.2, 0.3),
@@ -25,30 +31,30 @@ class EfficientNet(nn.Module):
         "l2": (4.3, 5.3, 0.5),
     }
 
-    def __init__(
-        self,
-        arch: str,
-        out_channels: int = 1280,
-        stochastic_dropout_rate: float = 0.2,
-        bn_momentum: float = 0.99,
-        bn_eps: float = 1.0e-3,
-    ) -> None:
-        super().__init__()
-        assert arch in self.archs, f"{arch} not a valid efficient net architecure!"
-        self.arch = self.archs[arch]
-        self.out_channels = out_channels
-        self.stochastic_dropout_rate = stochastic_dropout_rate
-        self.bn_momentum = bn_momentum
-        self.bn_eps = bn_eps
-        self._conv_stem: nn.Sequential = None
-        self._blocks: nn.ModuleList = None
-        self._conv_head: nn.Sequential = None
+    arch: str = attr.ib()
+    params: Tuple[float, float, float] = attr.ib(default=None, init=False)
+    out_channels: int = attr.ib(default=1280)
+    stochastic_dropout_rate: float = attr.ib(default=0.2)
+    bn_momentum: float = attr.ib(default=0.99)
+    bn_eps: float = attr.ib(default=1.0e-3)
+    _conv_stem: nn.Sequential = attr.ib(default=None, init=False)
+    _blocks: nn.ModuleList = attr.ib(default=None, init=False)
+    _conv_head: nn.Sequential = attr.ib(default=None, init=False)
+
+    def __attrs_post_init__(self) -> None:
+        """Post init configuration."""
         self._build()
 
+    @arch.validator
+    def check_arch(self, attribute: attr._make.Attribute, value: str) -> None:
+        if value not in self.archs:
+            raise ValueError(f"{value} not a valid architecure.")
+        self.params = self.archs[value]
+
     def _build(self) -> None:
         _block_args = block_args()
         in_channels = 1  # BW
-        out_channels = round_filters(32, self.arch)
+        out_channels = round_filters(32, self.params)
         self._conv_stem = nn.Sequential(
             nn.ZeroPad2d((0, 1, 0, 1)),
             nn.Conv2d(
@@ -65,9 +71,9 @@ class EfficientNet(nn.Module):
         )
         self._blocks = nn.ModuleList([])
         for args in _block_args:
-            args.in_channels = round_filters(args.in_channels, self.arch)
-            args.out_channels = round_filters(args.out_channels, self.arch)
-            args.num_repeats = round_repeats(args.num_repeats, self.arch)
+            args.in_channels = round_filters(args.in_channels, self.params)
+            args.out_channels = round_filters(args.out_channels, self.params)
+            args.num_repeats = round_repeats(args.num_repeats, self.params)
             for _ in range(args.num_repeats):
                 self._blocks.append(
                     MBConvBlock(
@@ -77,8 +83,8 @@ class EfficientNet(nn.Module):
                 args.in_channels = args.out_channels
                 args.stride = 1
 
-        in_channels = round_filters(320, self.arch)
-        out_channels = round_filters(self.out_channels, self.arch)
+        in_channels = round_filters(320, self.params)
+        out_channels = round_filters(self.out_channels, self.params)
         self._conv_head = nn.Sequential(
             nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
             nn.BatchNorm2d(
diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv.py b/text_recognizer/networks/encoders/efficientnet/mbconv.py
index e43771a..3aa63d0 100644
--- a/text_recognizer/networks/encoders/efficientnet/mbconv.py
+++ b/text_recognizer/networks/encoders/efficientnet/mbconv.py
@@ -26,7 +26,7 @@ class MBConvBlock(nn.Module):
     ) -> None:
         super().__init__()
         self.kernel_size = kernel_size
-        self.stride = (stride, ) * 2 if isinstance(stride, int) else stride
+        self.stride = (stride,) * 2 if isinstance(stride, int) else stride
         self.bn_momentum = bn_momentum
         self.bn_eps = bn_eps
         self.in_channels = in_channels
@@ -68,8 +68,7 @@ class MBConvBlock(nn.Module):
         inner_channels = in_channels * expand_ratio
         self._inverted_bottleneck = (
             self._configure_inverted_bottleneck(
-                in_channels=in_channels,
-                out_channels=inner_channels,
+                in_channels=in_channels, out_channels=inner_channels,
             )
             if expand_ratio != 1
             else None
@@ -98,9 +97,7 @@ class MBConvBlock(nn.Module):
         )
 
     def _configure_inverted_bottleneck(
-        self,
-        in_channels: int,
-        out_channels: int,
+        self, in_channels: int, out_channels: int,
     ) -> nn.Sequential:
         """Expansion phase."""
         return nn.Sequential(
diff --git a/text_recognizer/networks/transformer/__init__.py b/text_recognizer/networks/transformer/__init__.py
index a3f3011..51de619 100644
--- a/text_recognizer/networks/transformer/__init__.py
+++ b/text_recognizer/networks/transformer/__init__.py
@@ -1 +1,3 @@
 """Transformer modules."""
+from .layers import Decoder, Encoder
+from .transformer import Transformer
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 7bafc58..2770dc1 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -1,6 +1,7 @@
 """Implementes the attention module for the transformer."""
 from typing import Optional, Tuple
 
+import attr
 from einops import rearrange
 from einops.layers.torch import Rearrange
 import torch
@@ -14,31 +15,38 @@ from text_recognizer.networks.transformer.positional_encodings.rotary_embedding
 )
 
 
+@attr.s
 class Attention(nn.Module):
-    def __init__(
-        self,
-        dim: int,
-        num_heads: int,
-        dim_head: int = 64,
-        dropout_rate: float = 0.0,
-        causal: bool = False,
-    ) -> None:
+    """Standard attention."""
+
+    def __attrs_pre_init__(self) -> None:
         super().__init__()
-        self.scale = dim ** -0.5
-        self.num_heads = num_heads
-        self.causal = causal
-        inner_dim = dim * dim_head
+
+    dim: int = attr.ib()
+    num_heads: int = attr.ib()
+    dim_head: int = attr.ib(default=64)
+    dropout_rate: float = attr.ib(default=0.0)
+    casual: bool = attr.ib(default=False)
+    scale: float = attr.ib(init=False)
+    dropout: nn.Dropout = attr.ib(init=False)
+    fc: nn.Linear = attr.ib(init=False)
+    qkv_fn: nn.Sequential = attr.ib(init=False)
+    attn_fn: F.softmax = attr.ib(init=False, default=F.softmax)
+
+    def __attrs_post_init__(self) -> None:
+        """Post init configuration."""
+        self.scale = self.dim ** -0.5
+        inner_dim = self.dim * self.dim_head
 
         # Attnetion
         self.qkv_fn = nn.Sequential(
-            nn.Linear(dim, 3 * inner_dim, bias=False),
+            nn.Linear(self.dim, 3 * inner_dim, bias=False),
             Rearrange("b n (qkv h d) -> qkv b h n d", qkv=3, h=self.num_heads),
         )
-        self.dropout = nn.Dropout(dropout_rate)
-        self.attn_fn = F.softmax
+        self.dropout = nn.Dropout(p=self.dropout_rate)
 
         # Feedforward
-        self.fc = nn.Linear(inner_dim, dim)
+        self.fc = nn.Linear(inner_dim, self.dim)
 
     @staticmethod
     def _apply_rotary_emb(
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index 4daa265..9b2f236 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -1,67 +1,74 @@
 """Transformer attention layer."""
 from functools import partial
-from typing import Any, Dict, Optional, Tuple, Type
+from typing import Any, Dict, Optional, Tuple
 
+import attr
 from torch import nn, Tensor
 
-from .attention import Attention
-from .mlp import FeedForward
-from .residual import Residual
-from .positional_encodings.rotary_embedding import RotaryEmbedding
+from text_recognizer.networks.transformer.residual import Residual
+from text_recognizer.networks.transformer.positional_encodings.rotary_embedding import (
+    RotaryEmbedding,
+)
+from text_recognizer.networks.util import load_partial_fn
 
 
+@attr.s
 class AttentionLayers(nn.Module):
-    def __init__(
-        self,
-        dim: int,
-        depth: int,
-        num_heads: int,
-        ff_kwargs: Dict,
-        attn_kwargs: Dict,
-        attn_fn: Type[nn.Module] = Attention,
-        norm_fn: Type[nn.Module] = nn.LayerNorm,
-        ff_fn: Type[nn.Module] = FeedForward,
-        rotary_emb: Optional[Type[nn.Module]] = None,
-        rotary_emb_dim: Optional[int] = None,
-        causal: bool = False,
-        cross_attend: bool = False,
-        pre_norm: bool = True,
-    ) -> None:
+    """Standard transfomer layer."""
+
+    def __attrs_pre_init__(self) -> None:
         super().__init__()
-        self.dim = dim
-        attn_fn = partial(attn_fn, dim=dim, num_heads=num_heads, **attn_kwargs)
-        norm_fn = partial(norm_fn, dim)
-        ff_fn = partial(ff_fn, dim=dim, **ff_kwargs)
-        self.layer_types = self._get_layer_types(cross_attend) * depth
-        self.layers = self._build_network(causal, attn_fn, norm_fn, ff_fn)
-        rotary_emb_dim = max(rotary_emb_dim, 32) if rotary_emb_dim is not None else None
-        self.rotary_emb = RotaryEmbedding(rotary_emb_dim) if rotary_emb else None
-        self.pre_norm = pre_norm
-        self.has_pos_emb = True if self.rotary_emb is not None else False
 
-    @staticmethod
-    def _get_layer_types(cross_attend: bool) -> Tuple:
+    dim: int = attr.ib()
+    depth: int = attr.ib()
+    num_heads: int = attr.ib()
+    attn_fn: str = attr.ib()
+    attn_kwargs: Dict = attr.ib()
+    norm_fn: str = attr.ib()
+    ff_fn: str = attr.ib()
+    ff_kwargs: Dict = attr.ib()
+    causal: bool = attr.ib(default=False)
+    cross_attend: bool = attr.ib(default=False)
+    pre_norm: bool = attr.ib(default=True)
+    rotary_emb: Optional[RotaryEmbedding] = attr.ib(default=None, init=False)
+    has_pos_emb: bool = attr.ib(init=False)
+    layer_types: Tuple[str, ...] = attr.ib(init=False)
+    layers: nn.ModuleList = attr.ib(init=False)
+    attn: partial = attr.ib(init=False)
+    norm: partial = attr.ib(init=False)
+    ff: partial = attr.ib(init=False)
+
+    def __attrs_post_init__(self) -> None:
+        """Post init configuration."""
+        self.has_pos_emb = True if self.rotary_emb is not None else False
+        self.layer_types = self._get_layer_types() * self.depth
+        attn = load_partial_fn(
+            self.attn_fn, dim=self.dim, num_heads=self.num_heads, **self.attn_kwargs
+        )
+        norm = load_partial_fn(self.norm_fn, dim=self.dim)
+        ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs)
+        self.layers = self._build_network(attn, norm, ff)
+
+    def _get_layer_types(self) -> Tuple:
         """Get layer specification."""
-        if cross_attend:
+        if self.cross_attend:
             return "a", "c", "f"
         return "a", "f"
 
     def _build_network(
-        self, causal: bool, attn_fn: partial, norm_fn: partial, ff_fn: partial,
+        self, attn: partial, norm: partial, ff: partial,
     ) -> nn.ModuleList:
         """Configures transformer network."""
         layers = nn.ModuleList([])
         for layer_type in self.layer_types:
             if layer_type == "a":
-                layer = attn_fn(causal=causal)
+                layer = attn(causal=self.causal)
             elif layer_type == "c":
-                layer = attn_fn()
+                layer = attn()
             elif layer_type == "f":
-                layer = ff_fn()
-
+                layer = ff()
             residual_fn = Residual()
-
-            layers.append(nn.ModuleList([norm_fn(), layer, residual_fn]))
+            layers.append(nn.ModuleList([norm(), layer, residual_fn]))
         return layers
 
     def forward(
@@ -72,12 +79,10 @@ class AttentionLayers(nn.Module):
         context_mask: Optional[Tensor] = None,
     ) -> Tensor:
         rotary_pos_emb = self.rotary_emb(x) if self.rotary_emb is not None else None
-
         for i, (layer_type, (norm, block, residual_fn)) in enumerate(
             zip(self.layer_types, self.layers)
         ):
             is_last = i == len(self.layers) - 1
-
             residual = x
 
             if self.pre_norm:
diff --git a/text_recognizer/networks/util.py b/text_recognizer/networks/util.py
index 85094f1..e822c57 100644
--- a/text_recognizer/networks/util.py
+++ b/text_recognizer/networks/util.py
@@ -1,5 +1,7 @@
 """Miscellaneous neural network utility functionality."""
-from typing import Type
+from functools import partial
+from importlib import import_module
+from typing import Any, Type
 
 from torch import nn
 
@@ -19,3 +21,9 @@ def activation_function(activation: str) -> Type[nn.Module]:
         ]
     )
     return activation_fns[activation.lower()]
+
+
+def load_partial_fn(fn: str, **kwargs: Any) -> partial:
+    """Loads partial function."""
+    module = import_module(".".join(fn.split(".")[:-1]))
+    return partial(getattr(module, fn.split(".")[0]), **kwargs)
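
For context, the new load_partial_fn helper in text_recognizer/networks/util.py is what lets AttentionLayers take dotted-path strings (attn_fn, norm_fn, ff_fn) plus kwargs instead of class objects. Below is a minimal, self-contained sketch of that pattern, not the patch itself: the math.cos target is only an illustration, and this sketch resolves the target from the last path segment of the dotted string.

from functools import partial
from importlib import import_module
from typing import Any


def load_partial_fn(fn: str, **kwargs: Any) -> partial:
    """Import the module part of a dotted path and bind kwargs to the target callable."""
    module = import_module(".".join(fn.split(".")[:-1]))
    # Assumes the callable is named by the last segment of the dotted path.
    return partial(getattr(module, fn.split(".")[-1]), **kwargs)


# Illustration with a standard-library target (not part of the patch above):
cos = load_partial_fn("math.cos")
print(cos(0.0))  # 1.0

# AttentionLayers.__attrs_post_init__ applies the same pattern, roughly:
# attn = load_partial_fn("text_recognizer.networks.transformer.attention.Attention",
#                        dim=dim, num_heads=num_heads, **attn_kwargs)
# layer = attn(causal=True)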