summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/encoders/efficientnet/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/encoders/efficientnet/utils.py')
-rw-r--r--text_recognizer/networks/encoders/efficientnet/utils.py107
1 files changed, 5 insertions, 102 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/utils.py b/text_recognizer/networks/encoders/efficientnet/utils.py
index ff52485..6f293db 100644
--- a/text_recognizer/networks/encoders/efficientnet/utils.py
+++ b/text_recognizer/networks/encoders/efficientnet/utils.py
@@ -1,27 +1,15 @@
"""Util functions for efficient net."""
from functools import partial
import math
-from typing import Any, Optional, Tuple, Type
+from typing import Any, Optional, Union, Tuple, Type
from omegaconf import OmegaConf
import torch
-from torch import nn, Tensor
-import torch.functional as F
+from torch import Tensor
-def calculate_output_image_size(
- image_size: Optional[Tuple[int, int]], stride: int
-) -> Optional[Tuple[int, int]]:
- """Calculates the output image size when using conv2d with same padding."""
- if image_size is None:
- return None
- height = int(math.ceil(image_size[0] / stride))
- width = int(math.ceil(image_size[1] / stride))
- return height, width
-
-
-def drop_connection(x: Tensor, p: float, training: bool) -> Tensor:
- """Drop connection.
+def stochastic_depth(x: Tensor, p: float, training: bool) -> Tensor:
+ """Stochastic connection.
Drops the entire convolution with a given survival probability.
@@ -57,91 +45,6 @@ def drop_connection(x: Tensor, p: float, training: bool) -> Tensor:
return out
-def get_same_padding_conv2d(image_size: Optional[Tuple[int, int]]) -> Type[nn.Conv2d]:
- if image_size is None:
- return Conv2dDynamicSamePadding
- return partial(Conv2dStaticSamePadding, image_size=image_size)
-
-
-class Conv2dDynamicSamePadding(nn.Conv2d):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- kernel_size: int,
- stride: int = 1,
- dilation: int = 1,
- groups: int = 1,
- bias: bool = True,
- ) -> None:
- super().__init__(
- in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias
- )
- self.stride = [self.stride] * 2
-
- def forward(self, x: Tensor) -> Tensor:
- ih, iw = x.shape[-2:]
- kh, kw = self.weight.shape[-2:]
- sh, sw = self.stride
- oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
- pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
- pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
- if pad_h > 0 or pad_w > 0:
- x = F.pad(
- x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
- )
- return F.conv2d(
- x,
- self.weight,
- self.bias,
- self.stride,
- self.padding,
- self.dilation,
- self.groups,
- )
-
-
-class Conv2dStaticSamePadding(nn.Conv2d):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- kernel_size: int,
- image_size: Tuple[int, int],
- stride: int = 1,
- **kwargs: Any
- ):
- super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs)
- self.stride = [self.stride] * 2
-
- # Calculate padding based on image size and save it.
- ih, iw = image_size
- kh, kw = self.weight.shape[-2:]
- sh, sw = self.stride
- oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
- pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
- pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
- if pad_h > 0 or pad_w > 0:
- self.static_padding = nn.ZeroPad2d(
- (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
- )
- else:
- self.static_padding = nn.Identity()
-
- def forward(self, x: Tensor) -> Tensor:
- x = self.static_padding(x)
- x = F.pad(
- x,
- self.weight,
- self.bias,
- self.stride,
- self.padding,
- self.dilation,
- self.groups,
- )
- return x
-
-
def round_filters(filters: int, arch: Tuple[float, float, float]) -> int:
multiplier = arch[0]
divisor = 8
@@ -160,7 +63,7 @@ def block_args():
keys = [
"num_repeats",
"kernel_size",
- "strides",
+ "stride",
"expand_ratio",
"in_channels",
"out_channels",