summaryrefslogtreecommitdiff
path: root/text_recognizer/networks
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--text_recognizer/networks/encoders/efficientnet/efficientnet.py12
1 files changed, 8 insertions, 4 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
index b8eb53b..2237acf 100644
--- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
@@ -35,10 +35,10 @@ class EfficientNet(nn.Module):
arch: str = attr.ib()
params: Tuple[float, float, float] = attr.ib(default=None, init=False)
- out_channels: int = attr.ib(default=1280)
stochastic_dropout_rate: float = attr.ib(default=0.2)
bn_momentum: float = attr.ib(default=0.99)
bn_eps: float = attr.ib(default=1.0e-3)
+ out_channels: int = attr.ib(default=None, init=False)
_conv_stem: nn.Sequential = attr.ib(default=None, init=False)
_blocks: nn.ModuleList = attr.ib(default=None, init=False)
_conv_head: nn.Sequential = attr.ib(default=None, init=False)
@@ -89,11 +89,15 @@ class EfficientNet(nn.Module):
args.stride = 1
in_channels = round_filters(320, self.params)
- out_channels = round_filters(self.out_channels, self.params)
+ self.out_channels = round_filters(1280, self.params)
self._conv_head = nn.Sequential(
- nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
+ nn.Conv2d(
+ in_channels, self.out_channels, kernel_size=1, stride=1, bias=False
+ ),
nn.BatchNorm2d(
- num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
+ num_features=self.out_channels,
+ momentum=self.bn_momentum,
+ eps=self.bn_eps,
),
)