diff options
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/cnn_tranformer.py | 30 | ||||
-rw-r--r-- | text_recognizer/networks/encoders/efficientnet/efficientnet.py | 58 | ||||
-rw-r--r-- | text_recognizer/networks/encoders/efficientnet/mbconv.py | 9 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/__init__.py | 2 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/attention.py | 40 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/layers.py | 91 | ||||
-rw-r--r-- | text_recognizer/networks/util.py | 10 |
7 files changed, 135 insertions, 105 deletions
diff --git a/text_recognizer/networks/cnn_tranformer.py b/text_recognizer/networks/cnn_tranformer.py index e030cb8..ce7ec43 100644 --- a/text_recognizer/networks/cnn_tranformer.py +++ b/text_recognizer/networks/cnn_tranformer.py @@ -7,6 +7,7 @@ import torch from torch import nn, Tensor from text_recognizer.data.mappings import AbstractMapping +from text_recognizer.networks.encoders.efficientnet import EfficientNet from text_recognizer.networks.transformer.layers import Decoder from text_recognizer.networks.transformer.positional_encodings import ( PositionalEncoding, @@ -15,7 +16,7 @@ from text_recognizer.networks.transformer.positional_encodings import ( @attr.s -class CnnTransformer(nn.Module): +class Reader(nn.Module): def __attrs_pre_init__(self) -> None: super().__init__() @@ -27,21 +28,20 @@ class CnnTransformer(nn.Module): num_classes: int = attr.ib() padding_idx: int = attr.ib() start_token: str = attr.ib() - start_index: int = attr.ib(init=False, default=None) + start_index: int = attr.ib(init=False) end_token: str = attr.ib() - end_index: int = attr.ib(init=False, default=None) + end_index: int = attr.ib(init=False) pad_token: str = attr.ib() - pad_index: int = attr.ib(init=False, default=None) + pad_index: int = attr.ib(init=False) # Modules. - encoder: Type[nn.Module] = attr.ib() + encoder: EfficientNet = attr.ib() decoder: Decoder = attr.ib() - embedding: nn.Embedding = attr.ib(init=False, default=None) - latent_encoder: nn.Sequential = attr.ib(init=False, default=None) - token_embedding: nn.Embedding = attr.ib(init=False, default=None) - token_pos_encoder: PositionalEncoding = attr.ib(init=False, default=None) - head: nn.Linear = attr.ib(init=False, default=None) - mapping: AbstractMapping = attr.ib(init=False, default=None) + latent_encoder: nn.Sequential = attr.ib(init=False) + token_embedding: nn.Embedding = attr.ib(init=False) + token_pos_encoder: PositionalEncoding = attr.ib(init=False) + head: nn.Linear = attr.ib(init=False) + mapping: Type[AbstractMapping] = attr.ib(init=False) def __attrs_post_init__(self) -> None: """Post init configuration.""" @@ -187,12 +187,16 @@ class CnnTransformer(nn.Module): output[:, i : i + 1] = tokens[-1:] # Early stopping of prediction loop if token is end or padding token. - if (output[:, i - 1] == self.end_index | output[: i - 1] == self.pad_index).all(): + if ( + output[:, i - 1] == self.end_index | output[: i - 1] == self.pad_index + ).all(): break # Set all tokens after end token to pad token. for i in range(1, self.max_output_len): - idx = (output[:, i -1] == self.end_index | output[:, i - 1] == self.pad_index) + idx = ( + output[:, i - 1] == self.end_index | output[:, i - 1] == self.pad_index + ) output[idx, i] = self.pad_index return output diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py index 6719efb..a36150a 100644 --- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py +++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py @@ -1,4 +1,7 @@ """Efficient net.""" +from typing import Tuple + +import attr from torch import nn, Tensor from .mbconv import MBConvBlock @@ -9,10 +12,13 @@ from .utils import ( ) +@attr.s class EfficientNet(nn.Module): - # TODO: attr + def __attrs_pre_init__(self) -> None: + super().__init__() + archs = { - # width,depth0res,dropout + # width, depth, dropout "b0": (1.0, 1.0, 0.2), "b1": (1.0, 1.1, 0.2), "b2": (1.1, 1.2, 0.3), @@ -25,30 +31,30 @@ class EfficientNet(nn.Module): "l2": (4.3, 5.3, 0.5), } - def __init__( - self, - arch: str, - out_channels: int = 1280, - stochastic_dropout_rate: float = 0.2, - bn_momentum: float = 0.99, - bn_eps: float = 1.0e-3, - ) -> None: - super().__init__() - assert arch in self.archs, f"{arch} not a valid efficient net architecure!" - self.arch = self.archs[arch] - self.out_channels = out_channels - self.stochastic_dropout_rate = stochastic_dropout_rate - self.bn_momentum = bn_momentum - self.bn_eps = bn_eps - self._conv_stem: nn.Sequential = None - self._blocks: nn.ModuleList = None - self._conv_head: nn.Sequential = None + arch: str = attr.ib() + params: Tuple[float, float, float] = attr.ib(default=None, init=False) + out_channels: int = attr.ib(default=1280) + stochastic_dropout_rate: float = attr.ib(default=0.2) + bn_momentum: float = attr.ib(default=0.99) + bn_eps: float = attr.ib(default=1.0e-3) + _conv_stem: nn.Sequential = attr.ib(default=None, init=False) + _blocks: nn.ModuleList = attr.ib(default=None, init=False) + _conv_head: nn.Sequential = attr.ib(default=None, init=False) + + def __attrs_post_init__(self) -> None: + """Post init configuration.""" self._build() + @arch.validator + def check_arch(self, attribute: attr._make.Attribute, value: str) -> None: + if value not in self.archs: + raise ValueError(f"{value} not a valid architecure.") + self.params = self.archs[value] + def _build(self) -> None: _block_args = block_args() in_channels = 1 # BW - out_channels = round_filters(32, self.arch) + out_channels = round_filters(32, self.params) self._conv_stem = nn.Sequential( nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d( @@ -65,9 +71,9 @@ class EfficientNet(nn.Module): ) self._blocks = nn.ModuleList([]) for args in _block_args: - args.in_channels = round_filters(args.in_channels, self.arch) - args.out_channels = round_filters(args.out_channels, self.arch) - args.num_repeats = round_repeats(args.num_repeats, self.arch) + args.in_channels = round_filters(args.in_channels, self.params) + args.out_channels = round_filters(args.out_channels, self.params) + args.num_repeats = round_repeats(args.num_repeats, self.params) for _ in range(args.num_repeats): self._blocks.append( MBConvBlock( @@ -77,8 +83,8 @@ class EfficientNet(nn.Module): args.in_channels = args.out_channels args.stride = 1 - in_channels = round_filters(320, self.arch) - out_channels = round_filters(self.out_channels, self.arch) + in_channels = round_filters(320, self.params) + out_channels = round_filters(self.out_channels, self.params) self._conv_head = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False), nn.BatchNorm2d( diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv.py b/text_recognizer/networks/encoders/efficientnet/mbconv.py index e43771a..3aa63d0 100644 --- a/text_recognizer/networks/encoders/efficientnet/mbconv.py +++ b/text_recognizer/networks/encoders/efficientnet/mbconv.py @@ -26,7 +26,7 @@ class MBConvBlock(nn.Module): ) -> None: super().__init__() self.kernel_size = kernel_size - self.stride = (stride, ) * 2 if isinstance(stride, int) else stride + self.stride = (stride,) * 2 if isinstance(stride, int) else stride self.bn_momentum = bn_momentum self.bn_eps = bn_eps self.in_channels = in_channels @@ -68,8 +68,7 @@ class MBConvBlock(nn.Module): inner_channels = in_channels * expand_ratio self._inverted_bottleneck = ( self._configure_inverted_bottleneck( - in_channels=in_channels, - out_channels=inner_channels, + in_channels=in_channels, out_channels=inner_channels, ) if expand_ratio != 1 else None @@ -98,9 +97,7 @@ class MBConvBlock(nn.Module): ) def _configure_inverted_bottleneck( - self, - in_channels: int, - out_channels: int, + self, in_channels: int, out_channels: int, ) -> nn.Sequential: """Expansion phase.""" return nn.Sequential( diff --git a/text_recognizer/networks/transformer/__init__.py b/text_recognizer/networks/transformer/__init__.py index a3f3011..51de619 100644 --- a/text_recognizer/networks/transformer/__init__.py +++ b/text_recognizer/networks/transformer/__init__.py @@ -1 +1,3 @@ """Transformer modules.""" +from .layers import Decoder, Encoder +from .transformer import Transformer diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py index 7bafc58..2770dc1 100644 --- a/text_recognizer/networks/transformer/attention.py +++ b/text_recognizer/networks/transformer/attention.py @@ -1,6 +1,7 @@ """Implementes the attention module for the transformer.""" from typing import Optional, Tuple +import attr from einops import rearrange from einops.layers.torch import Rearrange import torch @@ -14,31 +15,38 @@ from text_recognizer.networks.transformer.positional_encodings.rotary_embedding ) +@attr.s class Attention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - dim_head: int = 64, - dropout_rate: float = 0.0, - causal: bool = False, - ) -> None: + """Standard attention.""" + + def __attrs_pre_init__(self) -> None: super().__init__() - self.scale = dim ** -0.5 - self.num_heads = num_heads - self.causal = causal - inner_dim = dim * dim_head + + dim: int = attr.ib() + num_heads: int = attr.ib() + dim_head: int = attr.ib(default=64) + dropout_rate: float = attr.ib(default=0.0) + casual: bool = attr.ib(default=False) + scale: float = attr.ib(init=False) + dropout: nn.Dropout = attr.ib(init=False) + fc: nn.Linear = attr.ib(init=False) + qkv_fn: nn.Sequential = attr.ib(init=False) + attn_fn: F.softmax = attr.ib(init=False, default=F.softmax) + + def __attrs_post_init__(self) -> None: + """Post init configuration.""" + self.scale = self.dim ** -0.5 + inner_dim = self.dim * self.dim_head # Attnetion self.qkv_fn = nn.Sequential( - nn.Linear(dim, 3 * inner_dim, bias=False), + nn.Linear(self.dim, 3 * inner_dim, bias=False), Rearrange("b n (qkv h d) -> qkv b h n d", qkv=3, h=self.num_heads), ) - self.dropout = nn.Dropout(dropout_rate) - self.attn_fn = F.softmax + self.dropout = nn.Dropout(p=self.dropout_rate) # Feedforward - self.fc = nn.Linear(inner_dim, dim) + self.fc = nn.Linear(inner_dim, self.dim) @staticmethod def _apply_rotary_emb( diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py index 4daa265..9b2f236 100644 --- a/text_recognizer/networks/transformer/layers.py +++ b/text_recognizer/networks/transformer/layers.py @@ -1,67 +1,74 @@ """Transformer attention layer.""" from functools import partial -from typing import Any, Dict, Optional, Tuple, Type +from typing import Any, Dict, Optional, Tuple +import attr from torch import nn, Tensor -from .attention import Attention -from .mlp import FeedForward -from .residual import Residual -from .positional_encodings.rotary_embedding import RotaryEmbedding +from text_recognizer.networks.transformer.residual import Residual +from text_recognizer.networks.transformer.positional_encodings.rotary_embedding import ( + RotaryEmbedding, +) +from text_recognizer.networks.util import load_partial_fn +@attr.s class AttentionLayers(nn.Module): - def __init__( - self, - dim: int, - depth: int, - num_heads: int, - ff_kwargs: Dict, - attn_kwargs: Dict, - attn_fn: Type[nn.Module] = Attention, - norm_fn: Type[nn.Module] = nn.LayerNorm, - ff_fn: Type[nn.Module] = FeedForward, - rotary_emb: Optional[Type[nn.Module]] = None, - rotary_emb_dim: Optional[int] = None, - causal: bool = False, - cross_attend: bool = False, - pre_norm: bool = True, - ) -> None: + """Standard transfomer layer.""" + + def __attrs_pre_init__(self) -> None: super().__init__() - self.dim = dim - attn_fn = partial(attn_fn, dim=dim, num_heads=num_heads, **attn_kwargs) - norm_fn = partial(norm_fn, dim) - ff_fn = partial(ff_fn, dim=dim, **ff_kwargs) - self.layer_types = self._get_layer_types(cross_attend) * depth - self.layers = self._build_network(causal, attn_fn, norm_fn, ff_fn) - rotary_emb_dim = max(rotary_emb_dim, 32) if rotary_emb_dim is not None else None - self.rotary_emb = RotaryEmbedding(rotary_emb_dim) if rotary_emb else None - self.pre_norm = pre_norm - self.has_pos_emb = True if self.rotary_emb is not None else False - @staticmethod - def _get_layer_types(cross_attend: bool) -> Tuple: + dim: int = attr.ib() + depth: int = attr.ib() + num_heads: int = attr.ib() + attn_fn: str = attr.ib() + attn_kwargs: Dict = attr.ib() + norm_fn: str = attr.ib() + ff_fn: str = attr.ib() + ff_kwargs: Dict = attr.ib() + causal: bool = attr.ib(default=False) + cross_attend: bool = attr.ib(default=False) + pre_norm: bool = attr.ib(default=True) + rotary_emb: Optional[RotaryEmbedding] = attr.ib(default=None, init=False) + has_pos_emb: bool = attr.ib(init=False) + layer_types: Tuple[str, ...] = attr.ib(init=False) + layers: nn.ModuleList = attr.ib(init=False) + attn: partial = attr.ib(init=False) + norm: partial = attr.ib(init=False) + ff: partial = attr.ib(init=False) + + def __attrs_post_init__(self) -> None: + """Post init configuration.""" + self.has_pos_emb = True if self.rotary_emb is not None else False + self.layer_types = self._get_layer_types() * self.depth + attn = load_partial_fn( + self.attn_fn, dim=self.dim, num_heads=self.num_heads, **self.attn_kwargs + ) + norm = load_partial_fn(self.norm_fn, dim=self.dim) + ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs) + self.layers = self._build_network(attn, norm, ff) + + def _get_layer_types(self) -> Tuple: """Get layer specification.""" - if cross_attend: + if self.cross_attend: return "a", "c", "f" return "a", "f" def _build_network( - self, causal: bool, attn_fn: partial, norm_fn: partial, ff_fn: partial, + self, attn: partial, norm: partial, ff: partial, ) -> nn.ModuleList: """Configures transformer network.""" layers = nn.ModuleList([]) for layer_type in self.layer_types: if layer_type == "a": - layer = attn_fn(causal=causal) + layer = attn(causal=self.causal) elif layer_type == "c": - layer = attn_fn() + layer = attn() elif layer_type == "f": - layer = ff_fn() - + layer = ff() residual_fn = Residual() - - layers.append(nn.ModuleList([norm_fn(), layer, residual_fn])) + layers.append(nn.ModuleList([norm(), layer, residual_fn])) return layers def forward( @@ -72,12 +79,10 @@ class AttentionLayers(nn.Module): context_mask: Optional[Tensor] = None, ) -> Tensor: rotary_pos_emb = self.rotary_emb(x) if self.rotary_emb is not None else None - for i, (layer_type, (norm, block, residual_fn)) in enumerate( zip(self.layer_types, self.layers) ): is_last = i == len(self.layers) - 1 - residual = x if self.pre_norm: diff --git a/text_recognizer/networks/util.py b/text_recognizer/networks/util.py index 85094f1..e822c57 100644 --- a/text_recognizer/networks/util.py +++ b/text_recognizer/networks/util.py @@ -1,5 +1,7 @@ """Miscellaneous neural network utility functionality.""" -from typing import Type +from functools import partial +from importlib import import_module +from typing import Any, Type from torch import nn @@ -19,3 +21,9 @@ def activation_function(activation: str) -> Type[nn.Module]: ] ) return activation_fns[activation.lower()] + + +def load_partial_fn(fn: str, **kwargs: Any) -> partial: + """Loads partial function.""" + module = import_module(".".join(fn.split(".")[:-1])) + return partial(getattr(module, fn.split(".")[0]), **kwargs) |