diff options
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/encoders/efficientnet/efficientnet.py | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py index b8eb53b..2237acf 100644 --- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py +++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py @@ -35,10 +35,10 @@ class EfficientNet(nn.Module): arch: str = attr.ib() params: Tuple[float, float, float] = attr.ib(default=None, init=False) - out_channels: int = attr.ib(default=1280) stochastic_dropout_rate: float = attr.ib(default=0.2) bn_momentum: float = attr.ib(default=0.99) bn_eps: float = attr.ib(default=1.0e-3) + out_channels: int = attr.ib(default=None, init=False) _conv_stem: nn.Sequential = attr.ib(default=None, init=False) _blocks: nn.ModuleList = attr.ib(default=None, init=False) _conv_head: nn.Sequential = attr.ib(default=None, init=False) @@ -89,11 +89,15 @@ class EfficientNet(nn.Module): args.stride = 1 in_channels = round_filters(320, self.params) - out_channels = round_filters(self.out_channels, self.params) + self.out_channels = round_filters(1280, self.params) self._conv_head = nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False), + nn.Conv2d( + in_channels, self.out_channels, kernel_size=1, stride=1, bias=False + ), nn.BatchNorm2d( - num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps + num_features=self.out_channels, + momentum=self.bn_momentum, + eps=self.bn_eps, ), ) |