diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-02 21:13:48 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-02 21:13:48 +0200 |
commit | 75801019981492eedf9280cb352eea3d8e99b65f (patch) | |
tree | 6521cc4134459e42591b2375f70acd348741474e /text_recognizer/networks | |
parent | e5eca28438cd17d436359f2c6eee0bb9e55d2a8b (diff) |
Fix log import, fix mapping in datamodules, fix nn modules can be hashed
Diffstat (limited to 'text_recognizer/networks')
5 files changed, 75 insertions, 95 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index 7371be4..09cc654 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -13,7 +13,7 @@ from text_recognizer.networks.transformer.positional_encodings import ( ) -@attr.s +@attr.s(eq=False) class ConvTransformer(nn.Module): """Convolutional encoder and transformer decoder network.""" @@ -121,6 +121,7 @@ class ConvTransformer(nn.Module): Returns: Tensor: Sequence of word piece embeddings. """ + context = context.long() context_mask = context != self.pad_index context = self.token_embedding(context) * math.sqrt(self.hidden_dim) context = self.token_pos_encoder(context) diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py index a36150a..b8eb53b 100644 --- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py +++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py @@ -1,4 +1,4 @@ -"""Efficient net.""" +"""Efficientnet backbone.""" from typing import Tuple import attr @@ -12,8 +12,10 @@ from .utils import ( ) -@attr.s +@attr.s(eq=False) class EfficientNet(nn.Module): + """Efficientnet without classification head.""" + def __attrs_pre_init__(self) -> None: super().__init__() @@ -47,11 +49,13 @@ class EfficientNet(nn.Module): @arch.validator def check_arch(self, attribute: attr._make.Attribute, value: str) -> None: + """Validates the efficientnet architecure.""" if value not in self.archs: raise ValueError(f"{value} not a valid architecure.") self.params = self.archs[value] def _build(self) -> None: + """Builds the efficientnet backbone.""" _block_args = block_args() in_channels = 1 # BW out_channels = round_filters(32, self.params) @@ -73,8 +77,9 @@ class EfficientNet(nn.Module): for args in _block_args: args.in_channels = round_filters(args.in_channels, self.params) args.out_channels = round_filters(args.out_channels, self.params) - args.num_repeats = round_repeats(args.num_repeats, self.params) - for _ in range(args.num_repeats): + num_repeats = round_repeats(args.num_repeats, self.params) + del args.num_repeats + for _ in range(num_repeats): self._blocks.append( MBConvBlock( **args, bn_momentum=self.bn_momentum, bn_eps=self.bn_eps, @@ -93,6 +98,7 @@ class EfficientNet(nn.Module): ) def extract_features(self, x: Tensor) -> Tensor: + """Extracts the final feature map layer.""" x = self._conv_stem(x) for i, block in enumerate(self._blocks): stochastic_dropout_rate = self.stochastic_dropout_rate @@ -103,4 +109,5 @@ class EfficientNet(nn.Module): return x def forward(self, x: Tensor) -> Tensor: + """Returns efficientnet image features.""" return self.extract_features(x) diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv.py b/text_recognizer/networks/encoders/efficientnet/mbconv.py index 3aa63d0..e85df87 100644 --- a/text_recognizer/networks/encoders/efficientnet/mbconv.py +++ b/text_recognizer/networks/encoders/efficientnet/mbconv.py @@ -1,76 +1,62 @@ """Mobile inverted residual block.""" -from typing import Any, Optional, Union, Tuple +from typing import Optional, Sequence, Union, Tuple +import attr import torch from torch import nn, Tensor import torch.nn.functional as F -from .utils import stochastic_depth +from text_recognizer.networks.encoders.efficientnet.utils import stochastic_depth +def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]: + """Converts int to tuple.""" + return ( + (stride,) * 2 if isinstance(stride, int) else stride + ) + + +@attr.s(eq=False) class MBConvBlock(nn.Module): """Mobile Inverted Residual Bottleneck block.""" - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - stride: Union[Tuple[int, int], int], - bn_momentum: float, - bn_eps: float, - se_ratio: float, - expand_ratio: int, - *args: Any, - **kwargs: Any, - ) -> None: + def __attrs_pre_init__(self) -> None: super().__init__() - self.kernel_size = kernel_size - self.stride = (stride,) * 2 if isinstance(stride, int) else stride - self.bn_momentum = bn_momentum - self.bn_eps = bn_eps - self.in_channels = in_channels - self.out_channels = out_channels + in_channels: int = attr.ib() + out_channels: int = attr.ib() + kernel_size: Tuple[int, int] = attr.ib() + stride: Tuple[int, int] = attr.ib(converter=_convert_stride) + bn_momentum: float = attr.ib() + bn_eps: float = attr.ib() + se_ratio: float = attr.ib() + expand_ratio: int = attr.ib() + pad: Tuple[int, int, int, int] = attr.ib(init=False) + _inverted_bottleneck: nn.Sequential = attr.ib(init=False) + _depthwise: nn.Sequential = attr.ib(init=False) + _squeeze_excite: nn.Sequential = attr.ib(init=False) + _pointwise: nn.Sequential = attr.ib(init=False) + + @pad.default + def _configure_padding(self) -> Tuple[int, int, int, int]: + """Set padding for convolutional layers.""" if self.stride == (2, 2): - self.pad = [ + return ( (self.kernel_size - 1) // 2 - 1, (self.kernel_size - 1) // 2, - ] * 2 - else: - self.pad = [(self.kernel_size - 1) // 2] * 4 - - # Placeholders for layers. - self._inverted_bottleneck: nn.Sequential = None - self._depthwise: nn.Sequential = None - self._squeeze_excite: nn.Sequential = None - self._pointwise: nn.Sequential = None - - self._build( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - expand_ratio=expand_ratio, - se_ratio=se_ratio, - ) + ) * 2 + return ((self.kernel_size - 1) // 2,) * 4 + + def __attrs_post_init__(self) -> None: + """Post init configuration.""" + self._build() - def _build( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - stride: Union[Tuple[int, int], int], - expand_ratio: int, - se_ratio: float, - ) -> None: - has_se = se_ratio is not None and 0.0 < se_ratio < 1.0 - inner_channels = in_channels * expand_ratio + def _build(self) -> None: + has_se = self.se_ratio is not None and 0.0 < self.se_ratio < 1.0 + inner_channels = self.in_channels * self.expand_ratio self._inverted_bottleneck = ( - self._configure_inverted_bottleneck( - in_channels=in_channels, out_channels=inner_channels, - ) - if expand_ratio != 1 + self._configure_inverted_bottleneck(out_channels=inner_channels) + if self.expand_ratio != 1 else None ) @@ -78,31 +64,23 @@ class MBConvBlock(nn.Module): in_channels=inner_channels, out_channels=inner_channels, groups=inner_channels, - kernel_size=kernel_size, - stride=stride, ) self._squeeze_excite = ( self._configure_squeeze_excite( - in_channels=inner_channels, - out_channels=inner_channels, - se_ratio=se_ratio, + in_channels=inner_channels, out_channels=inner_channels, ) if has_se else None ) - self._pointwise = self._configure_pointwise( - in_channels=inner_channels, out_channels=out_channels - ) + self._pointwise = self._configure_pointwise(in_channels=inner_channels) - def _configure_inverted_bottleneck( - self, in_channels: int, out_channels: int, - ) -> nn.Sequential: + def _configure_inverted_bottleneck(self, out_channels: int) -> nn.Sequential: """Expansion phase.""" return nn.Sequential( nn.Conv2d( - in_channels=in_channels, + in_channels=self.in_channels, out_channels=out_channels, kernel_size=1, bias=False, @@ -114,19 +92,14 @@ class MBConvBlock(nn.Module): ) def _configure_depthwise( - self, - in_channels: int, - out_channels: int, - groups: int, - kernel_size: int, - stride: Union[Tuple[int, int], int], + self, in_channels: int, out_channels: int, groups: int, ) -> nn.Sequential: return nn.Sequential( nn.Conv2d( in_channels=in_channels, out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, + kernel_size=self.kernel_size, + stride=self.stride, groups=groups, bias=False, ), @@ -137,9 +110,9 @@ class MBConvBlock(nn.Module): ) def _configure_squeeze_excite( - self, in_channels: int, out_channels: int, se_ratio: float + self, in_channels: int, out_channels: int ) -> nn.Sequential: - num_squeezed_channels = max(1, int(in_channels * se_ratio)) + num_squeezed_channels = max(1, int(in_channels * self.se_ratio)) return nn.Sequential( nn.Conv2d( in_channels=in_channels, @@ -154,18 +127,18 @@ class MBConvBlock(nn.Module): ), ) - def _configure_pointwise( - self, in_channels: int, out_channels: int - ) -> nn.Sequential: + def _configure_pointwise(self, in_channels: int) -> nn.Sequential: return nn.Sequential( nn.Conv2d( in_channels=in_channels, - out_channels=out_channels, + out_channels=self.out_channels, kernel_size=1, bias=False, ), nn.BatchNorm2d( - num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps + num_features=self.out_channels, + momentum=self.bn_momentum, + eps=self.bn_eps, ), ) @@ -186,8 +159,8 @@ class MBConvBlock(nn.Module): residual = x if self._inverted_bottleneck is not None: x = self._inverted_bottleneck(x) - x = F.pad(x, self.pad) + x = F.pad(x, self.pad) x = self._depthwise(x) if self._squeeze_excite is not None: diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py index 9202cce..37ce29e 100644 --- a/text_recognizer/networks/transformer/attention.py +++ b/text_recognizer/networks/transformer/attention.py @@ -15,7 +15,7 @@ from text_recognizer.networks.transformer.positional_encodings.rotary_embedding ) -@attr.s +@attr.s(eq=False) class Attention(nn.Module): """Standard attention.""" @@ -31,7 +31,6 @@ class Attention(nn.Module): dropout: nn.Dropout = attr.ib(init=False) fc: nn.Linear = attr.ib(init=False) qkv_fn: nn.Sequential = attr.ib(init=False) - attn_fn: F.softmax = attr.ib(init=False, default=F.softmax) def __attrs_post_init__(self) -> None: """Post init configuration.""" @@ -80,7 +79,7 @@ class Attention(nn.Module): else k_mask ) q_mask = rearrange(q_mask, "b i -> b () i ()") - k_mask = rearrange(k_mask, "b i -> b () () j") + k_mask = rearrange(k_mask, "b j -> b () () j") return q_mask * k_mask return @@ -129,7 +128,7 @@ class Attention(nn.Module): if self.causal: energy = self._apply_causal_mask(energy, mask, mask_value, device) - attn = self.attn_fn(energy, dim=-1) + attn = F.softmax(energy, dim=-1) attn = self.dropout(attn) out = einsum("b h i j, b h j d -> b h i d", attn, v) out = rearrange(out, "b h n d -> b n (h d)") diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py index 66c9c50..ce443e5 100644 --- a/text_recognizer/networks/transformer/layers.py +++ b/text_recognizer/networks/transformer/layers.py @@ -12,7 +12,7 @@ from text_recognizer.networks.transformer.positional_encodings.rotary_embedding from text_recognizer.networks.util import load_partial_fn -@attr.s +@attr.s(eq=False) class AttentionLayers(nn.Module): """Standard transfomer layer.""" @@ -101,11 +101,11 @@ class AttentionLayers(nn.Module): return x -@attr.s(auto_attribs=True) +@attr.s(auto_attribs=True, eq=False) class Encoder(AttentionLayers): causal: bool = attr.ib(default=False, init=False) -@attr.s(auto_attribs=True) +@attr.s(auto_attribs=True, eq=False) class Decoder(AttentionLayers): causal: bool = attr.ib(default=True, init=False) |