From 80e9bed0dd9840ac0cc9de1c6c1be3b6fed90cf9 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 5 Jul 2021 00:22:27 +0200 Subject: Add new updates to model and networks, reimplementing with attr --- text_recognizer/networks/encoders/efficientnet/efficientnet.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) (limited to 'text_recognizer/networks') diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py index a59abf8..fb4f002 100644 --- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py +++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py @@ -27,7 +27,7 @@ class EfficientNet(nn.Module): def __init__( self, arch: str, - out_channels: int = 256, + out_channels: int = 1280, stochastic_dropout_rate: float = 0.2, bn_momentum: float = 0.99, bn_eps: float = 1.0e-3, @@ -37,7 +37,7 @@ class EfficientNet(nn.Module): self.arch = self.archs[arch] self.out_channels = out_channels self.stochastic_dropout_rate = stochastic_dropout_rate - self.bn_momentum = 1 - bn_momentum + self.bn_momentum = bn_momentum self.bn_eps = bn_eps self._conv_stem: nn.Sequential = None self._blocks: nn.Sequential = None @@ -70,9 +70,7 @@ class EfficientNet(nn.Module): for _ in range(args.num_repeats): self._blocks.append( MBConvBlock( - **args, - bn_momentum=self.bn_momentum, - bn_eps=self.bn_eps, + **args, bn_momentum=self.bn_momentum, bn_eps=self.bn_eps, ) ) args.in_channels = args.out_channels @@ -94,7 +92,7 @@ class EfficientNet(nn.Module): if self.stochastic_dropout_rate: stochastic_dropout_rate *= i / len(self._blocks) x = block(x, stochastic_dropout_rate=stochastic_dropout_rate) - self._conv_head(x) + x = self._conv_head(x) return x def forward(self, x: Tensor) -> Tensor: -- cgit v1.2.3-70-g09d2