From ea525029b8b0355c656280e491796b4821c491a4 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Fri, 5 Nov 2021 19:25:42 +0100 Subject: Remove out_channels as a settable parameter in effnet --- .../networks/encoders/efficientnet/efficientnet.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) (limited to 'text_recognizer/networks') diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py index b8eb53b..2237acf 100644 --- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py +++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py @@ -35,10 +35,10 @@ class EfficientNet(nn.Module): arch: str = attr.ib() params: Tuple[float, float, float] = attr.ib(default=None, init=False) - out_channels: int = attr.ib(default=1280) stochastic_dropout_rate: float = attr.ib(default=0.2) bn_momentum: float = attr.ib(default=0.99) bn_eps: float = attr.ib(default=1.0e-3) + out_channels: int = attr.ib(default=None, init=False) _conv_stem: nn.Sequential = attr.ib(default=None, init=False) _blocks: nn.ModuleList = attr.ib(default=None, init=False) _conv_head: nn.Sequential = attr.ib(default=None, init=False) @@ -89,11 +89,15 @@ class EfficientNet(nn.Module): args.stride = 1 in_channels = round_filters(320, self.params) - out_channels = round_filters(self.out_channels, self.params) + self.out_channels = round_filters(1280, self.params) self._conv_head = nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False), + nn.Conv2d( + in_channels, self.out_channels, kernel_size=1, stride=1, bias=False + ), nn.BatchNorm2d( - num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps + num_features=self.out_channels, + momentum=self.bn_momentum, + eps=self.bn_eps, ), ) -- cgit v1.2.3-70-g09d2