From 42bf19d923c00c7be66d993d47cebf434035d1be Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Wed, 22 Jun 2022 22:26:39 +0200
Subject: Add extra conv layer to stem

---
 text_recognizer/networks/efficientnet/efficientnet.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

(limited to 'text_recognizer/networks')

diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py
index c94c748..bd47e4b 100644
--- a/text_recognizer/networks/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/efficientnet/efficientnet.py
@@ -59,7 +59,7 @@ class EfficientNet(nn.Module):
         """Builds the efficientnet backbone."""
         _block_args = block_args()[: self.depth]
         in_channels = 1 # BW
-        out_channels = round_filters(16, self.params)
+        out_channels = round_filters(32, self.params)
         self._conv_stem = nn.Sequential(
             nn.ZeroPad2d((0, 1, 0, 1)),
             nn.Conv2d(
@@ -73,6 +73,17 @@ class EfficientNet(nn.Module):
                 num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
             ),
             nn.Mish(inplace=True),
+            nn.Conv2d(
+                in_channels=out_channels,
+                out_channels=out_channels,
+                kernel_size=3,
+                stride=2,
+                bias=False,
+            ),
+            nn.BatchNorm2d(
+                num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
+            ),
+            nn.Mish(inplace=True),
         )
         self._blocks = nn.ModuleList([])
         for args in _block_args:
-- 
cgit v1.2.3-70-g09d2