summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-06-22 22:26:39 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-06-22 22:26:39 +0200
commit42bf19d923c00c7be66d993d47cebf434035d1be (patch)
treeb7442ce6a2f02a108a92670a527c95ab4b77fda1
parent801ff44c68c4208c75c78fe400a72d82ef494778 (diff)
Add extra conv layer to stem
-rw-r--r--text_recognizer/networks/efficientnet/efficientnet.py13
1 files changed, 12 insertions, 1 deletions
diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py
index c94c748..bd47e4b 100644
--- a/text_recognizer/networks/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/efficientnet/efficientnet.py
@@ -59,7 +59,7 @@ class EfficientNet(nn.Module):
"""Builds the efficientnet backbone."""
_block_args = block_args()[: self.depth]
in_channels = 1 # BW
- out_channels = round_filters(16, self.params)
+ out_channels = round_filters(32, self.params)
self._conv_stem = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(
@@ -73,6 +73,17 @@ class EfficientNet(nn.Module):
num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
),
nn.Mish(inplace=True),
+ nn.Conv2d(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=2,
+ bias=False,
+ ),
+ nn.BatchNorm2d(
+ num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
+ ),
+ nn.Mish(inplace=True),
)
self._blocks = nn.ModuleList([])
for args in _block_args: