diff options
author | aktersnurra <grydholm@kth.se> | 2021-06-14 23:21:21 +0200 |
---|---|---|
committer | aktersnurra <grydholm@kth.se> | 2021-06-14 23:21:21 +0200 |
commit | fd971d09fd6167ac42bd5aeb5e64a719dc1c370b (patch) | |
tree | 001d7b92ee5d5bfcb60b186e8e13705c88a3db32 /text_recognizer/networks/encoders/efficient_net | |
parent | 57525da0f267300792cd6b65e59914644a2dd39b (diff) |
added part of forward pass in mbconv
Diffstat (limited to 'text_recognizer/networks/encoders/efficient_net')
-rw-r--r-- | text_recognizer/networks/encoders/efficient_net/block.py | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/text_recognizer/networks/encoders/efficient_net/block.py b/text_recognizer/networks/encoders/efficient_net/block.py index 4302bbd..d9a0416 100644 --- a/text_recognizer/networks/encoders/efficient_net/block.py +++ b/text_recognizer/networks/encoders/efficient_net/block.py @@ -3,6 +3,7 @@ from typing import Tuple import torch from torch import nn, Tensor +from torch.nn import functional as F from .utils import get_same_padding_conv2d @@ -141,3 +142,19 @@ class MBConvBlock(nn.Module): num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps ), ) + + def forward(self, x: Tensor, drop_connection_rate: Optional[float]) -> Tensor: + residual = x + if self._inverted_bottleneck is not None: + x = self._inverted_bottleneck(x) + + x = self._depthwise(x) + + if self._squeeze_excite is not None: + x_squeezed = F.adaptive_avg_pool2d(x, 1) + x_squeezed = self._squeeze_excite(x) + x = torch.sigmoid(x_squeezed) * x + + x = self._pointwise(x) + + # Stochastic depth |