diff options
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/encoders/efficient_net/block.py | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/text_recognizer/networks/encoders/efficient_net/block.py b/text_recognizer/networks/encoders/efficient_net/block.py index 4302bbd..d9a0416 100644 --- a/text_recognizer/networks/encoders/efficient_net/block.py +++ b/text_recognizer/networks/encoders/efficient_net/block.py @@ -3,6 +3,7 @@ from typing import Tuple import torch from torch import nn, Tensor +from torch.nn import functional as F from .utils import get_same_padding_conv2d @@ -141,3 +142,19 @@ class MBConvBlock(nn.Module): num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps ), ) + + def forward(self, x: Tensor, drop_connection_rate: Optional[float]) -> Tensor: + residual = x + if self._inverted_bottleneck is not None: + x = self._inverted_bottleneck(x) + + x = self._depthwise(x) + + if self._squeeze_excite is not None: + x_squeezed = F.adaptive_avg_pool2d(x, 1) + x_squeezed = self._squeeze_excite(x) + x = torch.sigmoid(x_squeezed) * x + + x = self._pointwise(x) + + # Stochastic depth |