diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-05 00:22:27 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-05 00:22:27 +0200 |
commit | 80e9bed0dd9840ac0cc9de1c6c1be3b6fed90cf9 (patch) | |
tree | 1e8dec8051220b2673cfa254991165cc6820b203 /text_recognizer/networks/encoders | |
parent | 437ba4e22b2dad2ca309085a2e97f33dd03eb642 (diff) |
Add new updates to model and networks, reimplementing with attr
Diffstat (limited to 'text_recognizer/networks/encoders')
-rw-r--r-- | text_recognizer/networks/encoders/efficientnet/efficientnet.py | 10 |
1 files changed, 4 insertions, 6 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py index a59abf8..fb4f002 100644 --- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py +++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py @@ -27,7 +27,7 @@ class EfficientNet(nn.Module): def __init__( self, arch: str, - out_channels: int = 256, + out_channels: int = 1280, stochastic_dropout_rate: float = 0.2, bn_momentum: float = 0.99, bn_eps: float = 1.0e-3, @@ -37,7 +37,7 @@ class EfficientNet(nn.Module): self.arch = self.archs[arch] self.out_channels = out_channels self.stochastic_dropout_rate = stochastic_dropout_rate - self.bn_momentum = 1 - bn_momentum + self.bn_momentum = bn_momentum self.bn_eps = bn_eps self._conv_stem: nn.Sequential = None self._blocks: nn.Sequential = None @@ -70,9 +70,7 @@ class EfficientNet(nn.Module): for _ in range(args.num_repeats): self._blocks.append( MBConvBlock( - **args, - bn_momentum=self.bn_momentum, - bn_eps=self.bn_eps, + **args, bn_momentum=self.bn_momentum, bn_eps=self.bn_eps, ) ) args.in_channels = args.out_channels @@ -94,7 +92,7 @@ class EfficientNet(nn.Module): if self.stochastic_dropout_rate: stochastic_dropout_rate *= i / len(self._blocks) x = block(x, stochastic_dropout_rate=stochastic_dropout_rate) - self._conv_head(x) + x = self._conv_head(x) return x def forward(self, x: Tensor) -> Tensor: |