diff options
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/efficientnet/efficientnet.py | 32 | ||||
-rw-r--r-- | text_recognizer/networks/efficientnet/mbconv.py | 71 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/attention.py | 24 |
3 files changed, 65 insertions, 62 deletions
diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py index 4c9ed75..cf64bcf 100644 --- a/text_recognizer/networks/efficientnet/efficientnet.py +++ b/text_recognizer/networks/efficientnet/efficientnet.py @@ -1,7 +1,7 @@ """Efficientnet backbone.""" from typing import Tuple -import attr +from attrs import define, field from torch import nn, Tensor from text_recognizer.networks.efficientnet.mbconv import MBConvBlock @@ -12,7 +12,7 @@ from text_recognizer.networks.efficientnet.utils import ( ) -@attr.s(eq=False) +@define(eq=False) class EfficientNet(nn.Module): """Efficientnet without classification head.""" @@ -33,28 +33,28 @@ class EfficientNet(nn.Module): "l2": (4.3, 5.3, 0.5), } - arch: str = attr.ib() - params: Tuple[float, float, float] = attr.ib(default=None, init=False) - stochastic_dropout_rate: float = attr.ib(default=0.2) - bn_momentum: float = attr.ib(default=0.99) - bn_eps: float = attr.ib(default=1.0e-3) - depth: int = attr.ib(default=7) - out_channels: int = attr.ib(default=None, init=False) - _conv_stem: nn.Sequential = attr.ib(default=None, init=False) - _blocks: nn.ModuleList = attr.ib(default=None, init=False) - _conv_head: nn.Sequential = attr.ib(default=None, init=False) + arch: str = field() + params: Tuple[float, float, float] = field(default=None, init=False) + stochastic_dropout_rate: float = field(default=0.2) + bn_momentum: float = field(default=0.99) + bn_eps: float = field(default=1.0e-3) + depth: int = field(default=7) + out_channels: int = field(default=None, init=False) + _conv_stem: nn.Sequential = field(default=None, init=False) + _blocks: nn.ModuleList = field(default=None, init=False) + _conv_head: nn.Sequential = field(default=None, init=False) def __attrs_post_init__(self) -> None: """Post init configuration.""" self._build() @depth.validator - def _check_depth(self, attribute: attr._make.Attribute, value: str) -> None: + def _check_depth(self, attribute, value: str) -> None: if not 5 <= value <= 7: raise ValueError(f"Depth has to be between 5 and 7, was: {value}") @arch.validator - def _check_arch(self, attribute: attr._make.Attribute, value: str) -> None: + def _check_arch(self, attribute, value: str) -> None: """Validates the efficientnet architecure.""" if value not in self.archs: raise ValueError(f"{value} not a valid architecure.") @@ -88,7 +88,9 @@ class EfficientNet(nn.Module): for _ in range(num_repeats): self._blocks.append( MBConvBlock( - **args, bn_momentum=self.bn_momentum, bn_eps=self.bn_eps, + **args, + bn_momentum=self.bn_momentum, + bn_eps=self.bn_eps, ) ) args.in_channels = args.out_channels diff --git a/text_recognizer/networks/efficientnet/mbconv.py b/text_recognizer/networks/efficientnet/mbconv.py index beb7d57..98e9353 100644 --- a/text_recognizer/networks/efficientnet/mbconv.py +++ b/text_recognizer/networks/efficientnet/mbconv.py @@ -1,7 +1,7 @@ """Mobile inverted residual block.""" from typing import Optional, Tuple, Union -import attr +from attrs import define, field import torch from torch import nn, Tensor import torch.nn.functional as F @@ -14,13 +14,13 @@ def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]: return (stride,) * 2 if isinstance(stride, int) else stride -@attr.s(eq=False) +@define(eq=False) class BaseModule(nn.Module): """Base sub module class.""" - bn_momentum: float = attr.ib() - bn_eps: float = attr.ib() - block: nn.Sequential = attr.ib(init=False) + bn_momentum: float = field() + bn_eps: float = field() + block: nn.Sequential = field(init=False) def __attrs_pre_init__(self) -> None: super().__init__() @@ -36,12 +36,12 @@ class BaseModule(nn.Module): return self.block(x) -@attr.s(auto_attribs=True, eq=False) +@define(auto_attribs=True, eq=False) class InvertedBottleneck(BaseModule): """Inverted bottleneck module.""" - in_channels: int = attr.ib() - out_channels: int = attr.ib() + in_channels: int = field() + out_channels: int = field() def _build(self) -> None: self.block = nn.Sequential( @@ -60,13 +60,13 @@ class InvertedBottleneck(BaseModule): ) -@attr.s(auto_attribs=True, eq=False) +@define(auto_attribs=True, eq=False) class Depthwise(BaseModule): """Depthwise convolution module.""" - channels: int = attr.ib() - kernel_size: int = attr.ib() - stride: int = attr.ib() + channels: int = field() + kernel_size: int = field() + stride: int = field() def _build(self) -> None: self.block = nn.Sequential( @@ -85,13 +85,13 @@ class Depthwise(BaseModule): ) -@attr.s(auto_attribs=True, eq=False) +@define(auto_attribs=True, eq=False) class SqueezeAndExcite(BaseModule): """Sequeeze and excite module.""" - in_channels: int = attr.ib() - channels: int = attr.ib() - se_ratio: float = attr.ib() + in_channels: int = field() + channels: int = field() + se_ratio: float = field() def _build(self) -> None: num_squeezed_channels = max(1, int(self.in_channels * self.se_ratio)) @@ -110,12 +110,12 @@ class SqueezeAndExcite(BaseModule): ) -@attr.s(auto_attribs=True, eq=False) +@define(auto_attribs=True, eq=False) class Pointwise(BaseModule): """Pointwise module.""" - in_channels: int = attr.ib() - out_channels: int = attr.ib() + in_channels: int = field() + out_channels: int = field() def _build(self) -> None: self.block = nn.Sequential( @@ -133,32 +133,35 @@ class Pointwise(BaseModule): ) -@attr.s(eq=False) +@define(eq=False) class MBConvBlock(nn.Module): """Mobile Inverted Residual Bottleneck block.""" def __attrs_pre_init__(self) -> None: super().__init__() - in_channels: int = attr.ib() - out_channels: int = attr.ib() - kernel_size: Tuple[int, int] = attr.ib() - stride: Tuple[int, int] = attr.ib(converter=_convert_stride) - bn_momentum: float = attr.ib() - bn_eps: float = attr.ib() - se_ratio: float = attr.ib() - expand_ratio: int = attr.ib() - pad: Tuple[int, int, int, int] = attr.ib(init=False) - _inverted_bottleneck: Optional[InvertedBottleneck] = attr.ib(init=False) - _depthwise: nn.Sequential = attr.ib(init=False) - _squeeze_excite: nn.Sequential = attr.ib(init=False) - _pointwise: nn.Sequential = attr.ib(init=False) + in_channels: int = field() + out_channels: int = field() + kernel_size: Tuple[int, int] = field() + stride: Tuple[int, int] = field(converter=_convert_stride) + bn_momentum: float = field() + bn_eps: float = field() + se_ratio: float = field() + expand_ratio: int = field() + pad: Tuple[int, int, int, int] = field(init=False) + _inverted_bottleneck: Optional[InvertedBottleneck] = field(init=False) + _depthwise: nn.Sequential = field(init=False) + _squeeze_excite: nn.Sequential = field(init=False) + _pointwise: nn.Sequential = field(init=False) @pad.default def _configure_padding(self) -> Tuple[int, int, int, int]: """Set padding for convolutional layers.""" if self.stride == (2, 2): - return ((self.kernel_size - 1) // 2 - 1, (self.kernel_size - 1) // 2,) * 2 + return ( + (self.kernel_size - 1) // 2 - 1, + (self.kernel_size - 1) // 2, + ) * 2 return ((self.kernel_size - 1) // 2,) * 4 def __attrs_post_init__(self) -> None: diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py index 87792a9..aa15b88 100644 --- a/text_recognizer/networks/transformer/attention.py +++ b/text_recognizer/networks/transformer/attention.py @@ -1,7 +1,7 @@ """Implementes the attention module for the transformer.""" from typing import Optional, Tuple -import attr +from attrs import define, field from einops import rearrange import torch from torch import einsum @@ -15,22 +15,22 @@ from text_recognizer.networks.transformer.embeddings.rotary import ( ) -@attr.s(eq=False) +@define(eq=False) class Attention(nn.Module): """Standard attention.""" def __attrs_pre_init__(self) -> None: super().__init__() - dim: int = attr.ib() - num_heads: int = attr.ib() - causal: bool = attr.ib(default=False) - dim_head: int = attr.ib(default=64) - dropout_rate: float = attr.ib(default=0.0) - rotary_embedding: Optional[RotaryEmbedding] = attr.ib(default=None) - scale: float = attr.ib(init=False) - dropout: nn.Dropout = attr.ib(init=False) - fc: nn.Linear = attr.ib(init=False) + dim: int = field() + num_heads: int = field() + causal: bool = field(default=False) + dim_head: int = field(default=64) + dropout_rate: float = field(default=0.0) + rotary_embedding: Optional[RotaryEmbedding] = field(default=None) + scale: float = field(init=False) + dropout: nn.Dropout = field(init=False) + fc: nn.Linear = field(init=False) def __attrs_post_init__(self) -> None: self.scale = self.dim ** -0.5 @@ -120,7 +120,6 @@ def apply_input_mask( input_mask = q_mask * k_mask energy = energy.masked_fill_(~input_mask, mask_value) - del input_mask return energy @@ -133,5 +132,4 @@ def apply_causal_mask( mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j") mask = F.pad(mask, (j - i, 0), value=False) energy.masked_fill_(mask, mask_value) - del mask return energy |