summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/efficientnet/efficientnet.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/efficientnet/efficientnet.py')
-rw-r--r--text_recognizer/networks/efficientnet/efficientnet.py17
1 files changed, 2 insertions, 15 deletions
diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py
index de08457..2a712d8 100644
--- a/text_recognizer/networks/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/efficientnet/efficientnet.py
@@ -61,14 +61,14 @@ class EfficientNet(nn.Module):
"""Builds the efficientnet backbone."""
_block_args = block_args()[: self.depth]
in_channels = 1 # BW
- out_channels = round_filters(32, self.params)
+ out_channels = round_filters(16, self.params)
self._conv_stem = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
- stride=self.stride,
+ stride=2,
bias=False,
),
nn.BatchNorm2d(
@@ -98,19 +98,6 @@ class EfficientNet(nn.Module):
nn.Conv2d(
in_channels,
self.out_channels,
- kernel_size=2,
- stride=self.stride,
- bias=False,
- ),
- nn.BatchNorm2d(
- num_features=self.out_channels,
- momentum=self.bn_momentum,
- eps=self.bn_eps,
- ),
- nn.Mish(inplace=True),
- nn.Conv2d(
- self.out_channels,
- self.out_channels,
kernel_size=1,
stride=1,
bias=False,