From 22e36513dd43d2e2ca82ca28a1ea757c5663676a Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sat, 26 Jun 2021 00:35:02 +0200 Subject: Updates --- text_recognizer/networks/encoders/efficientnet/efficientnet.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'text_recognizer/networks') diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py index 283b686..a59abf8 100644 --- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py +++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py @@ -27,6 +27,7 @@ class EfficientNet(nn.Module): def __init__( self, arch: str, + out_channels: int = 256, stochastic_dropout_rate: float = 0.2, bn_momentum: float = 0.99, bn_eps: float = 1.0e-3, @@ -34,6 +35,7 @@ class EfficientNet(nn.Module): super().__init__() assert arch in self.archs, f"{arch} not a valid efficient net architecure!" self.arch = self.archs[arch] + self.out_channels = out_channels self.stochastic_dropout_rate = stochastic_dropout_rate self.bn_momentum = 1 - bn_momentum self.bn_eps = bn_eps @@ -77,7 +79,7 @@ class EfficientNet(nn.Module): args.stride = 1 in_channels = round_filters(320, self.arch) - out_channels = round_filters(1280, self.arch) + out_channels = round_filters(self.out_channels, self.arch) self._conv_head = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False), nn.BatchNorm2d( -- cgit v1.2.3-70-g09d2