diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-22 22:26:39 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-22 22:26:39 +0200 |
commit | 42bf19d923c00c7be66d993d47cebf434035d1be (patch) | |
tree | b7442ce6a2f02a108a92670a527c95ab4b77fda1 /text_recognizer/networks/efficientnet | |
parent | 801ff44c68c4208c75c78fe400a72d82ef494778 (diff) |
Add extra conv layer to stem
Diffstat (limited to 'text_recognizer/networks/efficientnet')
-rw-r--r-- | text_recognizer/networks/efficientnet/efficientnet.py | 13 |
1 files changed, 12 insertions, 1 deletions
diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py index c94c748..bd47e4b 100644 --- a/text_recognizer/networks/efficientnet/efficientnet.py +++ b/text_recognizer/networks/efficientnet/efficientnet.py @@ -59,7 +59,7 @@ class EfficientNet(nn.Module): """Builds the efficientnet backbone.""" _block_args = block_args()[: self.depth] in_channels = 1 # BW - out_channels = round_filters(16, self.params) + out_channels = round_filters(32, self.params) self._conv_stem = nn.Sequential( nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d( @@ -73,6 +73,17 @@ class EfficientNet(nn.Module): num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps ), nn.Mish(inplace=True), + nn.Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + stride=2, + bias=False, + ), + nn.BatchNorm2d( + num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps + ), + nn.Mish(inplace=True), ) self._blocks = nn.ModuleList([]) for args in _block_args: |