diff options
Diffstat (limited to 'text_recognizer/networks/efficientnet/efficientnet.py')
-rw-r--r-- | text_recognizer/networks/efficientnet/efficientnet.py | 5 |
1 files changed, 2 insertions, 3 deletions
diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py index 2260ee2..2f2508d 100644 --- a/text_recognizer/networks/efficientnet/efficientnet.py +++ b/text_recognizer/networks/efficientnet/efficientnet.py @@ -31,11 +31,11 @@ class EfficientNet(nn.Module): def __init__( self, arch: str, - params: Tuple[float, float, float], stochastic_dropout_rate: float = 0.2, bn_momentum: float = 0.99, bn_eps: float = 1.0e-3, depth: int = 7, + out_channels: int = 1280, ) -> None: super().__init__() self.params = self._get_arch_params(arch) @@ -43,7 +43,7 @@ class EfficientNet(nn.Module): self.bn_momentum = bn_momentum self.bn_eps = bn_eps self.depth = depth - self.out_channels: int + self.out_channels: int = out_channels self._conv_stem: nn.Sequential self._blocks: nn.ModuleList self._conv_head: nn.Sequential @@ -92,7 +92,6 @@ class EfficientNet(nn.Module): args.stride = 1 in_channels = round_filters(_block_args[-1].out_channels, self.params) - self.out_channels = round_filters(1280, self.params) self._conv_head = nn.Sequential( nn.Conv2d( in_channels, self.out_channels, kernel_size=1, stride=1, bias=False |