diff options
Diffstat (limited to 'text_recognizer/networks/efficientnet')
-rw-r--r-- | text_recognizer/networks/efficientnet/efficientnet.py | 13 |
1 files changed, 12 insertions, 1 deletions
diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py index c94c748..bd47e4b 100644 --- a/text_recognizer/networks/efficientnet/efficientnet.py +++ b/text_recognizer/networks/efficientnet/efficientnet.py @@ -59,7 +59,7 @@ class EfficientNet(nn.Module): """Builds the efficientnet backbone.""" _block_args = block_args()[: self.depth] in_channels = 1 # BW - out_channels = round_filters(16, self.params) + out_channels = round_filters(32, self.params) self._conv_stem = nn.Sequential( nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d( @@ -73,6 +73,17 @@ class EfficientNet(nn.Module): num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps ), nn.Mish(inplace=True), + nn.Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + stride=2, + bias=False, + ), + nn.BatchNorm2d( + num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps + ), + nn.Mish(inplace=True), ) self._blocks = nn.ModuleList([]) for args in _block_args: |