summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--text_recognizer/networks/encoders/efficient_net/block.py17
1 files changed, 17 insertions, 0 deletions
diff --git a/text_recognizer/networks/encoders/efficient_net/block.py b/text_recognizer/networks/encoders/efficient_net/block.py
index 4302bbd..d9a0416 100644
--- a/text_recognizer/networks/encoders/efficient_net/block.py
+++ b/text_recognizer/networks/encoders/efficient_net/block.py
@@ -3,6 +3,7 @@ from typing import Tuple
import torch
from torch import nn, Tensor
+from torch.nn import functional as F
from .utils import get_same_padding_conv2d
@@ -141,3 +142,19 @@ class MBConvBlock(nn.Module):
num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
),
)
+
+ def forward(self, x: Tensor, drop_connection_rate: Optional[float]) -> Tensor:
+ residual = x
+ if self._inverted_bottleneck is not None:
+ x = self._inverted_bottleneck(x)
+
+ x = self._depthwise(x)
+
+ if self._squeeze_excite is not None:
+ x_squeezed = F.adaptive_avg_pool2d(x, 1)
+ x_squeezed = self._squeeze_excite(x)
+ x = torch.sigmoid(x_squeezed) * x
+
+ x = self._pointwise(x)
+
+ # Stochastic depth