summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authoraktersnurra <grydholm@kth.se>2021-06-14 23:21:21 +0200
committeraktersnurra <grydholm@kth.se>2021-06-14 23:21:21 +0200
commitfd971d09fd6167ac42bd5aeb5e64a719dc1c370b (patch)
tree001d7b92ee5d5bfcb60b186e8e13705c88a3db32
parent57525da0f267300792cd6b65e59914644a2dd39b (diff)
added part of forward pass in mbconv
-rw-r--r--text_recognizer/networks/encoders/efficient_net/block.py17
1 files changed, 17 insertions, 0 deletions
diff --git a/text_recognizer/networks/encoders/efficient_net/block.py b/text_recognizer/networks/encoders/efficient_net/block.py
index 4302bbd..d9a0416 100644
--- a/text_recognizer/networks/encoders/efficient_net/block.py
+++ b/text_recognizer/networks/encoders/efficient_net/block.py
@@ -3,6 +3,7 @@ from typing import Tuple
import torch
from torch import nn, Tensor
+from torch.nn import functional as F
from .utils import get_same_padding_conv2d
@@ -141,3 +142,19 @@ class MBConvBlock(nn.Module):
num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
),
)
+
+ def forward(self, x: Tensor, drop_connection_rate: Optional[float]) -> Tensor:
+ residual = x
+ if self._inverted_bottleneck is not None:
+ x = self._inverted_bottleneck(x)
+
+ x = self._depthwise(x)
+
+ if self._squeeze_excite is not None:
+ x_squeezed = F.adaptive_avg_pool2d(x, 1)
+ x_squeezed = self._squeeze_excite(x)
+ x = torch.sigmoid(x_squeezed) * x
+
+ x = self._pointwise(x)
+
+ # Stochastic depth