diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-19 19:55:30 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-19 19:55:30 +0200 |
commit | 3cac3c32ff49618cfa05bb4e5630a62fc31af40a (patch) | |
tree | 1b4d7bb74cacb6fc1f4894b08d048ad305b7cef4 | |
parent | eac4f664719d38b467a33a496282cc6c4519be15 (diff) |
Add stride and extra layer to effnet
-rw-r--r-- | text_recognizer/networks/efficientnet/efficientnet.py | 23 |
1 files changed, 21 insertions, 2 deletions
diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py index 2f2508d..de08457 100644 --- a/text_recognizer/networks/efficientnet/efficientnet.py +++ b/text_recognizer/networks/efficientnet/efficientnet.py @@ -36,6 +36,7 @@ class EfficientNet(nn.Module): bn_eps: float = 1.0e-3, depth: int = 7, out_channels: int = 1280, + stride: Tuple[int, int] = (2, 2), ) -> None: super().__init__() self.params = self._get_arch_params(arch) @@ -43,6 +44,7 @@ class EfficientNet(nn.Module): self.bn_momentum = bn_momentum self.bn_eps = bn_eps self.depth = depth + self.stride = stride self.out_channels: int = out_channels self._conv_stem: nn.Sequential self._blocks: nn.ModuleList @@ -66,7 +68,7 @@ class EfficientNet(nn.Module): in_channels=in_channels, out_channels=out_channels, kernel_size=3, - stride=(2, 2), + stride=self.stride, bias=False, ), nn.BatchNorm2d( @@ -94,7 +96,24 @@ class EfficientNet(nn.Module): in_channels = round_filters(_block_args[-1].out_channels, self.params) self._conv_head = nn.Sequential( nn.Conv2d( - in_channels, self.out_channels, kernel_size=1, stride=1, bias=False + in_channels, + self.out_channels, + kernel_size=2, + stride=self.stride, + bias=False, + ), + nn.BatchNorm2d( + num_features=self.out_channels, + momentum=self.bn_momentum, + eps=self.bn_eps, + ), + nn.Mish(inplace=True), + nn.Conv2d( + self.out_channels, + self.out_channels, + kernel_size=1, + stride=1, + bias=False, ), nn.BatchNorm2d( num_features=self.out_channels, |