summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/efficientnet/mbconv.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/efficientnet/mbconv.py')
-rw-r--r--text_recognizer/networks/efficientnet/mbconv.py267
1 files changed, 0 insertions, 267 deletions
diff --git a/text_recognizer/networks/efficientnet/mbconv.py b/text_recognizer/networks/efficientnet/mbconv.py
deleted file mode 100644
index 9c97925..0000000
--- a/text_recognizer/networks/efficientnet/mbconv.py
+++ /dev/null
@@ -1,267 +0,0 @@
-"""Mobile inverted residual block."""
-from typing import Optional, Tuple, Union
-
-import torch
-from torch import nn, Tensor
-import torch.nn.functional as F
-
-from text_recognizer.networks.efficientnet.utils import stochastic_depth
-
-
-class BaseModule(nn.Module):
- """Base sub module class."""
-
- def __init__(self, bn_momentum: float, bn_eps: float) -> None:
- super().__init__()
-
- self.bn_momentum = bn_momentum
- self.bn_eps = bn_eps
- self._build()
-
- def _build(self) -> None:
- pass
-
- def forward(self, x: Tensor) -> Tensor:
- """Forward pass."""
- return self.block(x)
-
-
-class InvertedBottleneck(BaseModule):
- """Inverted bottleneck module."""
-
- def __init__(
- self,
- bn_momentum: float,
- bn_eps: float,
- in_channels: int,
- out_channels: int,
- ) -> None:
- self.in_channels = in_channels
- self.out_channels = out_channels
- super().__init__(bn_momentum, bn_eps)
-
- def _build(self) -> None:
- self.block = nn.Sequential(
- nn.Conv2d(
- in_channels=self.in_channels,
- out_channels=self.out_channels,
- kernel_size=1,
- bias=False,
- ),
- nn.BatchNorm2d(
- num_features=self.out_channels,
- momentum=self.bn_momentum,
- eps=self.bn_eps,
- ),
- nn.Mish(inplace=True),
- )
-
-
-class Depthwise(BaseModule):
- """Depthwise convolution module."""
-
- def __init__(
- self,
- bn_momentum: float,
- bn_eps: float,
- channels: int,
- kernel_size: int,
- stride: int,
- ) -> None:
- self.channels = channels
- self.kernel_size = kernel_size
- self.stride = stride
- super().__init__(bn_momentum, bn_eps)
-
- def _build(self) -> None:
- self.block = nn.Sequential(
- nn.Conv2d(
- in_channels=self.channels,
- out_channels=self.channels,
- kernel_size=self.kernel_size,
- stride=self.stride,
- groups=self.channels,
- bias=False,
- ),
- nn.BatchNorm2d(
- num_features=self.channels, momentum=self.bn_momentum, eps=self.bn_eps
- ),
- nn.Mish(inplace=True),
- )
-
-
-class SqueezeAndExcite(BaseModule):
- """Sequeeze and excite module."""
-
- def __init__(
- self,
- bn_momentum: float,
- bn_eps: float,
- in_channels: int,
- channels: int,
- se_ratio: float,
- ) -> None:
- self.in_channels = in_channels
- self.channels = channels
- self.se_ratio = se_ratio
- super().__init__(bn_momentum, bn_eps)
-
- def _build(self) -> None:
- num_squeezed_channels = max(1, int(self.in_channels * self.se_ratio))
- self.block = nn.Sequential(
- nn.Conv2d(
- in_channels=self.channels,
- out_channels=num_squeezed_channels,
- kernel_size=1,
- ),
- nn.Mish(inplace=True),
- nn.Conv2d(
- in_channels=num_squeezed_channels,
- out_channels=self.channels,
- kernel_size=1,
- ),
- )
-
-
-class Pointwise(BaseModule):
- """Pointwise module."""
-
- def __init__(
- self,
- bn_momentum: float,
- bn_eps: float,
- in_channels: int,
- out_channels: int,
- ) -> None:
- self.in_channels = in_channels
- self.out_channels = out_channels
- super().__init__(bn_momentum, bn_eps)
-
- def _build(self) -> None:
- self.block = nn.Sequential(
- nn.Conv2d(
- in_channels=self.in_channels,
- out_channels=self.out_channels,
- kernel_size=1,
- bias=False,
- ),
- nn.BatchNorm2d(
- num_features=self.out_channels,
- momentum=self.bn_momentum,
- eps=self.bn_eps,
- ),
- )
-
-
-class MBConvBlock(nn.Module):
- """Mobile Inverted Residual Bottleneck block."""
-
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- kernel_size: Tuple[int, int],
- stride: Tuple[int, int],
- bn_momentum: float,
- bn_eps: float,
- se_ratio: float,
- expand_ratio: int,
- ) -> None:
- super().__init__()
- self.in_channels = in_channels
- self.out_channels = out_channels
- self.kernel_size = kernel_size
- self.stride = stride
- self.bn_momentum = bn_momentum
- self.bn_eps = bn_eps
- self.se_ratio = se_ratio
- self.pad = self._configure_padding()
- self.expand_ratio = expand_ratio
- self._inverted_bottleneck: Optional[InvertedBottleneck]
- self._depthwise: nn.Sequential
- self._squeeze_excite: nn.Sequential
- self._pointwise: nn.Sequential
- self._build()
-
- def _configure_padding(self) -> Tuple[int, int, int, int]:
- """Set padding for convolutional layers."""
- if self.stride == (2, 2):
- return (
- (self.kernel_size - 1) // 2 - 1,
- (self.kernel_size - 1) // 2,
- ) * 2
- return ((self.kernel_size - 1) // 2,) * 4
-
- def _build(self) -> None:
- has_se = self.se_ratio is not None and 0.0 < self.se_ratio < 1.0
- inner_channels = self.in_channels * self.expand_ratio
- self._inverted_bottleneck = (
- InvertedBottleneck(
- in_channels=self.in_channels,
- out_channels=inner_channels,
- bn_momentum=self.bn_momentum,
- bn_eps=self.bn_eps,
- )
- if self.expand_ratio != 1
- else None
- )
-
- self._depthwise = Depthwise(
- channels=inner_channels,
- kernel_size=self.kernel_size,
- stride=self.stride,
- bn_momentum=self.bn_momentum,
- bn_eps=self.bn_eps,
- )
-
- self._squeeze_excite = (
- SqueezeAndExcite(
- in_channels=self.in_channels,
- channels=inner_channels,
- se_ratio=self.se_ratio,
- bn_momentum=self.bn_momentum,
- bn_eps=self.bn_eps,
- )
- if has_se
- else None
- )
-
- self._pointwise = Pointwise(
- in_channels=inner_channels,
- out_channels=self.out_channels,
- bn_momentum=self.bn_momentum,
- bn_eps=self.bn_eps,
- )
-
- def _stochastic_depth(
- self, x: Tensor, residual: Tensor, stochastic_dropout_rate: Optional[float]
- ) -> Tensor:
- if self.stride == (1, 1) and self.in_channels == self.out_channels:
- if stochastic_dropout_rate:
- x = stochastic_depth(
- x, p=stochastic_dropout_rate, training=self.training
- )
- x += residual
- return x
-
- def forward(
- self, x: Tensor, stochastic_dropout_rate: Optional[float] = None
- ) -> Tensor:
- """Forward pass."""
- residual = x
- if self._inverted_bottleneck is not None:
- x = self._inverted_bottleneck(x)
-
- x = F.pad(x, self.pad)
- x = self._depthwise(x)
-
- if self._squeeze_excite is not None:
- x_squeezed = F.adaptive_avg_pool2d(x, 1)
- x_squeezed = self._squeeze_excite(x)
- x = torch.tanh(F.softplus(x_squeezed)) * x
-
- x = self._pointwise(x)
-
- # Stochastic depth
- x = self._stochastic_depth(x, residual, stochastic_dropout_rate)
- return x