diff options
Diffstat (limited to 'text_recognizer/networks/encoders/efficientnet/efficientnet.py')
-rw-r--r-- | text_recognizer/networks/encoders/efficientnet/efficientnet.py | 118 |
1 files changed, 0 insertions, 118 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py deleted file mode 100644 index 9454514..0000000 --- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py +++ /dev/null @@ -1,118 +0,0 @@ -"""Efficientnet backbone.""" -from typing import Tuple - -import attr -from torch import nn, Tensor - -from .mbconv import MBConvBlock -from .utils import ( - block_args, - round_filters, - round_repeats, -) - - -@attr.s(eq=False) -class EfficientNet(nn.Module): - """Efficientnet without classification head.""" - - def __attrs_pre_init__(self) -> None: - super().__init__() - - archs = { - # width, depth, dropout - "b0": (1.0, 1.0, 0.2), - "b1": (1.0, 1.1, 0.2), - "b2": (1.1, 1.2, 0.3), - "b3": (1.2, 1.4, 0.3), - "b4": (1.4, 1.8, 0.4), - "b5": (1.6, 2.2, 0.4), - "b6": (1.8, 2.6, 0.5), - "b7": (2.0, 3.1, 0.5), - "b8": (2.2, 3.6, 0.5), - "l2": (4.3, 5.3, 0.5), - } - - arch: str = attr.ib() - params: Tuple[float, float, float] = attr.ib(default=None, init=False) - stochastic_dropout_rate: float = attr.ib(default=0.2) - bn_momentum: float = attr.ib(default=0.99) - bn_eps: float = attr.ib(default=1.0e-3) - out_channels: int = attr.ib(default=None, init=False) - _conv_stem: nn.Sequential = attr.ib(default=None, init=False) - _blocks: nn.ModuleList = attr.ib(default=None, init=False) - _conv_head: nn.Sequential = attr.ib(default=None, init=False) - - def __attrs_post_init__(self) -> None: - """Post init configuration.""" - self._build() - - @arch.validator - def check_arch(self, attribute: attr._make.Attribute, value: str) -> None: - """Validates the efficientnet architecure.""" - if value not in self.archs: - raise ValueError(f"{value} not a valid architecure.") - self.params = self.archs[value] - - def _build(self) -> None: - """Builds the efficientnet backbone.""" - _block_args = block_args() - in_channels = 1 # BW - out_channels = round_filters(32, self.params) - self._conv_stem = nn.Sequential( - nn.ZeroPad2d((0, 1, 0, 1)), - nn.Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - stride=(2, 2), - bias=False, - ), - nn.BatchNorm2d( - num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps - ), - nn.Mish(inplace=True), - ) - self._blocks = nn.ModuleList([]) - for args in _block_args: - args.in_channels = round_filters(args.in_channels, self.params) - args.out_channels = round_filters(args.out_channels, self.params) - num_repeats = round_repeats(args.num_repeats, self.params) - del args.num_repeats - for _ in range(num_repeats): - self._blocks.append( - MBConvBlock( - **args, bn_momentum=self.bn_momentum, bn_eps=self.bn_eps, - ) - ) - args.in_channels = args.out_channels - args.stride = 1 - - in_channels = round_filters(320, self.params) - self.out_channels = round_filters(1280, self.params) - self._conv_head = nn.Sequential( - nn.Conv2d( - in_channels, self.out_channels, kernel_size=1, stride=1, bias=False - ), - nn.BatchNorm2d( - num_features=self.out_channels, - momentum=self.bn_momentum, - eps=self.bn_eps, - ), - nn.Dropout(p=self.params[-1]), - ) - - def extract_features(self, x: Tensor) -> Tensor: - """Extracts the final feature map layer.""" - x = self._conv_stem(x) - for i, block in enumerate(self._blocks): - stochastic_dropout_rate = self.stochastic_dropout_rate - if self.stochastic_dropout_rate: - stochastic_dropout_rate *= i / len(self._blocks) - x = block(x, stochastic_dropout_rate=stochastic_dropout_rate) - x = self._conv_head(x) - return x - - def forward(self, x: Tensor) -> Tensor: - """Returns efficientnet image features.""" - return self.extract_features(x) |