summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/efficientnet
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/efficientnet')
-rw-r--r--text_recognizer/networks/efficientnet/efficientnet.py23
1 files changed, 21 insertions, 2 deletions
diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py
index 2f2508d..de08457 100644
--- a/text_recognizer/networks/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/efficientnet/efficientnet.py
@@ -36,6 +36,7 @@ class EfficientNet(nn.Module):
bn_eps: float = 1.0e-3,
depth: int = 7,
out_channels: int = 1280,
+ stride: Tuple[int, int] = (2, 2),
) -> None:
super().__init__()
self.params = self._get_arch_params(arch)
@@ -43,6 +44,7 @@ class EfficientNet(nn.Module):
self.bn_momentum = bn_momentum
self.bn_eps = bn_eps
self.depth = depth
+ self.stride = stride
self.out_channels: int = out_channels
self._conv_stem: nn.Sequential
self._blocks: nn.ModuleList
@@ -66,7 +68,7 @@ class EfficientNet(nn.Module):
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
- stride=(2, 2),
+ stride=self.stride,
bias=False,
),
nn.BatchNorm2d(
@@ -94,7 +96,24 @@ class EfficientNet(nn.Module):
in_channels = round_filters(_block_args[-1].out_channels, self.params)
self._conv_head = nn.Sequential(
nn.Conv2d(
- in_channels, self.out_channels, kernel_size=1, stride=1, bias=False
+ in_channels,
+ self.out_channels,
+ kernel_size=2,
+ stride=self.stride,
+ bias=False,
+ ),
+ nn.BatchNorm2d(
+ num_features=self.out_channels,
+ momentum=self.bn_momentum,
+ eps=self.bn_eps,
+ ),
+ nn.Mish(inplace=True),
+ nn.Conv2d(
+ self.out_channels,
+ self.out_channels,
+ kernel_size=1,
+ stride=1,
+ bias=False,
),
nn.BatchNorm2d(
num_features=self.out_channels,