summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/efficientnet
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-05 00:06:37 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-05 00:06:37 +0200
commitcfb460666953c87f606833bf597b53eba0a2900d (patch)
treeece2fcb5c9003a4dac8ed4f2108d063b07cdc274 /text_recognizer/networks/efficientnet
parentf95d51e45ea24a956ce4384e4680f849651b2506 (diff)
Format
Diffstat (limited to 'text_recognizer/networks/efficientnet')
-rw-r--r--text_recognizer/networks/efficientnet/efficientnet.py11
-rw-r--r--text_recognizer/networks/efficientnet/utils.py10
2 files changed, 5 insertions, 16 deletions
diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py
index bd47e4b..3481090 100644
--- a/text_recognizer/networks/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/efficientnet/efficientnet.py
@@ -73,17 +73,6 @@ class EfficientNet(nn.Module):
num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
),
nn.Mish(inplace=True),
- nn.Conv2d(
- in_channels=out_channels,
- out_channels=out_channels,
- kernel_size=3,
- stride=2,
- bias=False,
- ),
- nn.BatchNorm2d(
- num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
- ),
- nn.Mish(inplace=True),
)
self._blocks = nn.ModuleList([])
for args in _block_args:
diff --git a/text_recognizer/networks/efficientnet/utils.py b/text_recognizer/networks/efficientnet/utils.py
index 412d07d..5234324 100644
--- a/text_recognizer/networks/efficientnet/utils.py
+++ b/text_recognizer/networks/efficientnet/utils.py
@@ -74,11 +74,11 @@ def block_args() -> List[DictConfig]:
args = [
[1, 3, (1, 1), 1, 32, 16, 0.25],
[2, 3, (2, 2), 6, 16, 24, 0.25],
- [2, 5, (2, 1), 6, 24, 40, 0.25],
- [3, 3, (2, 1), 6, 40, 80, 0.25],
- [3, 5, (2, 1), 6, 80, 112, 0.25],
- [4, 5, (1, 1), 6, 112, 192, 0.25],
- [1, 3, (2, 1), 6, 192, 320, 0.25],
+ [2, 5, (2, 2), 6, 24, 40, 0.25],
+ [3, 3, (2, 2), 6, 40, 80, 0.25],
+ [3, 5, (1, 1), 6, 80, 112, 0.25],
+ [4, 5, (2, 2), 6, 112, 192, 0.25],
+ [1, 3, (1, 1), 6, 192, 320, 0.25],
]
block_args_ = []
for row in args: