diff options
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/efficientnet/efficientnet.py | 23 |
1 files changed, 21 insertions, 2 deletions
diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py index 2f2508d..de08457 100644 --- a/text_recognizer/networks/efficientnet/efficientnet.py +++ b/text_recognizer/networks/efficientnet/efficientnet.py @@ -36,6 +36,7 @@ class EfficientNet(nn.Module): bn_eps: float = 1.0e-3, depth: int = 7, out_channels: int = 1280, + stride: Tuple[int, int] = (2, 2), ) -> None: super().__init__() self.params = self._get_arch_params(arch) @@ -43,6 +44,7 @@ class EfficientNet(nn.Module): self.bn_momentum = bn_momentum self.bn_eps = bn_eps self.depth = depth + self.stride = stride self.out_channels: int = out_channels self._conv_stem: nn.Sequential self._blocks: nn.ModuleList @@ -66,7 +68,7 @@ class EfficientNet(nn.Module): in_channels=in_channels, out_channels=out_channels, kernel_size=3, - stride=(2, 2), + stride=self.stride, bias=False, ), nn.BatchNorm2d( @@ -94,7 +96,24 @@ class EfficientNet(nn.Module): in_channels = round_filters(_block_args[-1].out_channels, self.params) self._conv_head = nn.Sequential( nn.Conv2d( - in_channels, self.out_channels, kernel_size=1, stride=1, bias=False + in_channels, + self.out_channels, + kernel_size=2, + stride=self.stride, + bias=False, + ), + nn.BatchNorm2d( + num_features=self.out_channels, + momentum=self.bn_momentum, + eps=self.bn_eps, + ), + nn.Mish(inplace=True), + nn.Conv2d( + self.out_channels, + self.out_channels, + kernel_size=1, + stride=1, + bias=False, ), nn.BatchNorm2d( num_features=self.out_channels, |