summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/encoders
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-05 00:22:27 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-05 00:22:27 +0200
commit80e9bed0dd9840ac0cc9de1c6c1be3b6fed90cf9 (patch)
tree1e8dec8051220b2673cfa254991165cc6820b203 /text_recognizer/networks/encoders
parent437ba4e22b2dad2ca309085a2e97f33dd03eb642 (diff)
Add new updates to model and networks, reimplementing with attr
Diffstat (limited to 'text_recognizer/networks/encoders')
-rw-r--r--text_recognizer/networks/encoders/efficientnet/efficientnet.py10
1 files changed, 4 insertions, 6 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
index a59abf8..fb4f002 100644
--- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
@@ -27,7 +27,7 @@ class EfficientNet(nn.Module):
def __init__(
self,
arch: str,
- out_channels: int = 256,
+ out_channels: int = 1280,
stochastic_dropout_rate: float = 0.2,
bn_momentum: float = 0.99,
bn_eps: float = 1.0e-3,
@@ -37,7 +37,7 @@ class EfficientNet(nn.Module):
self.arch = self.archs[arch]
self.out_channels = out_channels
self.stochastic_dropout_rate = stochastic_dropout_rate
- self.bn_momentum = 1 - bn_momentum
+ self.bn_momentum = bn_momentum
self.bn_eps = bn_eps
self._conv_stem: nn.Sequential = None
self._blocks: nn.Sequential = None
@@ -70,9 +70,7 @@ class EfficientNet(nn.Module):
for _ in range(args.num_repeats):
self._blocks.append(
MBConvBlock(
- **args,
- bn_momentum=self.bn_momentum,
- bn_eps=self.bn_eps,
+ **args, bn_momentum=self.bn_momentum, bn_eps=self.bn_eps,
)
)
args.in_channels = args.out_channels
@@ -94,7 +92,7 @@ class EfficientNet(nn.Module):
if self.stochastic_dropout_rate:
stochastic_dropout_rate *= i / len(self._blocks)
x = block(x, stochastic_dropout_rate=stochastic_dropout_rate)
- self._conv_head(x)
+ x = self._conv_head(x)
return x
def forward(self, x: Tensor) -> Tensor: