summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/encoders/efficientnet/efficientnet.py
blob: 2237acf08ffb5e73e747fee79844c4b82bc7530d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""Efficientnet backbone."""
from typing import Tuple

import attr
from torch import nn, Tensor

from .mbconv import MBConvBlock
from .utils import (
    block_args,
    round_filters,
    round_repeats,
)


@attr.s(eq=False)
class EfficientNet(nn.Module):
    """Efficientnet without classification head.

    The network is built from three parts, applied in order by
    ``extract_features``: a strided 3x3 convolutional stem, a list of
    ``MBConvBlock`` modules and a 1x1 convolutional head. The stem is built
    for single-channel input.

    Args:
        arch: Key into ``archs`` selecting the (width, depth, dropout) tuple
            that is handed to ``round_filters`` / ``round_repeats``.
        stochastic_dropout_rate: Base rate forwarded to every block, scaled
            linearly with the block index. A falsy value is forwarded as is.
        bn_momentum: ``momentum`` passed to every ``nn.BatchNorm2d`` created
            here and forwarded to each ``MBConvBlock``.
        bn_eps: ``eps`` passed to every ``nn.BatchNorm2d`` created here and
            forwarded to each ``MBConvBlock``.

    Raises:
        ValueError: If ``arch`` is not a key of ``archs``.
    """

    def __attrs_pre_init__(self) -> None:
        # nn.Module must be initialised before attrs assigns any attribute,
        # otherwise registering sub-modules on ``self`` fails.
        super().__init__()

    archs = {
        # width, depth, dropout
        "b0": (1.0, 1.0, 0.2),
        "b1": (1.0, 1.1, 0.2),
        "b2": (1.1, 1.2, 0.3),
        "b3": (1.2, 1.4, 0.3),
        "b4": (1.4, 1.8, 0.4),
        "b5": (1.6, 2.2, 0.4),
        "b6": (1.8, 2.6, 0.5),
        "b7": (2.0, 3.1, 0.5),
        "b8": (2.2, 3.6, 0.5),
        "l2": (4.3, 5.3, 0.5),
    }

    arch: str = attr.ib()
    # Resolved from ``archs[arch]``; not a constructor argument.
    params: Tuple[float, float, float] = attr.ib(default=None, init=False)
    stochastic_dropout_rate: float = attr.ib(default=0.2)
    # NOTE(review): torch's BatchNorm2d ``momentum`` is the weight given to the
    # *new* batch statistic (torch default 0.1). 0.99 looks like the
    # TensorFlow-style decay value; confirm whether ``1 - bn_momentum`` was
    # intended here and in MBConvBlock.
    bn_momentum: float = attr.ib(default=0.99)
    bn_eps: float = attr.ib(default=1.0e-3)
    # Number of channels produced by the head; set in ``_build``.
    out_channels: int = attr.ib(default=None, init=False)
    _conv_stem: nn.Sequential = attr.ib(default=None, init=False)
    _blocks: nn.ModuleList = attr.ib(default=None, init=False)
    _conv_head: nn.Sequential = attr.ib(default=None, init=False)

    def __attrs_post_init__(self) -> None:
        """Post init configuration."""
        # The validator normally fills ``params``. Resolve it here as well so
        # the model still builds (or fails with the proper ValueError) when
        # attrs validators have been disabled globally.
        if self.params is None:
            self.params = self._arch_params(self.arch)
        self._build()

    @arch.validator
    def check_arch(self, attribute: attr.Attribute, value: str) -> None:
        """Validates the efficientnet architecture and stores its parameters."""
        self.params = self._arch_params(value)

    def _arch_params(self, arch: str) -> Tuple[float, float, float]:
        """Returns the (width, depth, dropout) tuple registered for ``arch``.

        Raises:
            ValueError: If ``arch`` is not a key of ``archs``.
        """
        if arch not in self.archs:
            raise ValueError(f"{arch} not a valid architecure.")
        return self.archs[arch]

    def _build(self) -> None:
        """Builds the efficientnet backbone."""
        _block_args = block_args()
        in_channels = 1  # BW
        out_channels = round_filters(32, self.params)
        self._conv_stem = nn.Sequential(
            # Pads one pixel on the right and bottom before the unpadded,
            # stride 2 convolution.
            nn.ZeroPad2d((0, 1, 0, 1)),
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=(2, 2),
                bias=False,
            ),
            nn.BatchNorm2d(
                num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
            ),
            nn.Mish(inplace=True),
        )
        self._blocks = nn.ModuleList([])
        # NOTE(review): the entries returned by block_args() are mutated in
        # place below; this presumes block_args() hands out fresh objects on
        # every call -- verify against its implementation.
        for args in _block_args:
            args.in_channels = round_filters(args.in_channels, self.params)
            args.out_channels = round_filters(args.out_channels, self.params)
            num_repeats = round_repeats(args.num_repeats, self.params)
            # ``num_repeats`` is consumed here and must not be forwarded to
            # MBConvBlock through ``**args``.
            del args.num_repeats
            for _ in range(num_repeats):
                self._blocks.append(
                    MBConvBlock(
                        **args, bn_momentum=self.bn_momentum, bn_eps=self.bn_eps,
                    )
                )
                # Every repeat after the first keeps the channel count and
                # uses stride 1.
                args.in_channels = args.out_channels
                args.stride = 1

        in_channels = round_filters(320, self.params)
        self.out_channels = round_filters(1280, self.params)
        # The head is a 1x1 convolution followed by batch norm; no activation
        # is applied after it.
        self._conv_head = nn.Sequential(
            nn.Conv2d(
                in_channels, self.out_channels, kernel_size=1, stride=1, bias=False
            ),
            nn.BatchNorm2d(
                num_features=self.out_channels,
                momentum=self.bn_momentum,
                eps=self.bn_eps,
            ),
        )

    def extract_features(self, x: Tensor) -> Tensor:
        """Extracts the final feature map layer.

        Runs ``x`` through the stem, every block and the head, and returns the
        output of the head.
        """
        x = self._conv_stem(x)
        num_blocks = len(self._blocks)
        for i, block in enumerate(self._blocks):
            stochastic_dropout_rate = self.stochastic_dropout_rate
            if stochastic_dropout_rate:
                # Scale linearly with depth: the first block gets a rate of 0,
                # later blocks approach the configured base rate.
                stochastic_dropout_rate *= i / num_blocks
            x = block(x, stochastic_dropout_rate=stochastic_dropout_rate)
        x = self._conv_head(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        """Returns efficientnet image features."""
        return self.extract_features(x)