summary refs log tree commit diff
path: root/text_recognizer/networks/encoders/efficientnet/efficientnet.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/encoders/efficientnet/efficientnet.py')
-rw-r--r-- text_recognizer/networks/encoders/efficientnet/efficientnet.py 68
1 files changed, 34 insertions, 34 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
index 98d58fd..b527d90 100644
--- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
@@ -1,13 +1,9 @@
"""Efficient net."""
-from typing import Tuple
-
from torch import nn, Tensor
from .mbconv import MBConvBlock
from .utils import (
block_args,
- calculate_output_image_size,
- get_same_padding_conv2d,
round_filters,
round_repeats,
)
@@ -28,11 +24,19 @@ class EfficientNet(nn.Module):
"l2": (4.3, 5.3, 0.5),
}
- def __init__(self, arch: str, image_size: Tuple[int, int]) -> None:
+ def __init__(
+ self,
+ arch: str,
+ stochastic_dropout_rate: float = 0.2,
+ bn_momentum: float = 0.99,
+ bn_eps: float = 1.0e-3,
+ ) -> None:
super().__init__()
assert arch in self.archs, f"{arch} not a valid efficient net architecure!"
self.arch = self.archs[arch]
- self.image_size = image_size
+ self.stochastic_dropout_rate = stochastic_dropout_rate
+ self.bn_momentum = 1 - bn_momentum
+ self.bn_eps = bn_eps
self._conv_stem: nn.Sequential = None
self._blocks: nn.Sequential = None
self._conv_head: nn.Sequential = None
@@ -42,57 +46,53 @@ class EfficientNet(nn.Module):
_block_args = block_args()
in_channels = 1 # BW
out_channels = round_filters(32, self.arch)
- Conv2d = get_same_padding_conv2d(image_size=self.image_size)
self._conv_stem = nn.Sequential(
- Conv2d(
+ nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
- stride=2,
+ stride=(2, 2),
bias=False,
),
- nn.BatchNorm2d(num_features=out_channels, momentum=bn_momentum, eps=bn_eps),
+ nn.BatchNorm2d(
+ num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
+ ),
nn.Mish(inplace=True),
)
- image_size = calculate_output_image_size(self.image_size, 2)
self._blocks = nn.ModuleList([])
for args in _block_args:
args.in_channels = round_filters(args.in_channels, self.arch)
args.out_channels = round_filters(args.out_channels, self.arch)
- args.num_repeat = round_repeats(args.num_repeat, self.arch)
-
- self._blocks.append(
- MBConvBlock(
- **args,
- bn_momentum=bn_momentum,
- bn_eps=bn_eps,
- image_size=image_size,
- )
- )
- image_size = calculate_output_image_size(image_size, args.stride)
- if args.num_repeat > 1:
- args.in_channels = args.out_channels
- args.stride = 1
- for _ in range(args.num_repeat - 1):
+ args.num_repeats = round_repeats(args.num_repeats, self.arch)
+ for _ in range(args.num_repeats):
self._blocks.append(
MBConvBlock(
**args,
- bn_momentum=bn_momentum,
- bn_eps=bn_eps,
- image_size=image_size,
+ bn_momentum=self.bn_momentum,
+ bn_eps=self.bn_eps,
)
)
+ args.in_channels = args.out_channels
+ args.stride = 1
- in_channels = args.out_channels
+ in_channels = round_filters(320, self.arch)
out_channels = round_filters(1280, self.arch)
- Conv2d = get_same_padding_conv2d(image_size=image_size)
self._conv_head = nn.Sequential(
- Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
- nn.BatchNorm2d(num_features=out_channels, momentum=bn_momentum, eps=bn_eps),
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
+ nn.BatchNorm2d(
+ num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
+ ),
)
def extract_features(self, x: Tensor) -> Tensor:
x = self._conv_stem(x)
+ for i, block in enumerate(self._blocks):
+ stochastic_dropout_rate = self.stochastic_dropout_rate
+ if self.stochastic_dropout_rate:
+ stochastic_dropout_rate *= i / len(self._blocks)
+ x = block(x, stochastic_dropout_rate=stochastic_dropout_rate)
+ self._conv_head(x)
+ return x
def forward(self, x: Tensor) -> Tensor:
    """Return the features extracted from ``x``.

    Thin entry point that delegates to ``self.extract_features``.
    """
    features = self.extract_features(x)
    return features