summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/efficientnet/efficientnet.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/efficientnet/efficientnet.py')
-rw-r--r--text_recognizer/networks/efficientnet/efficientnet.py5
1 files changed, 2 insertions, 3 deletions
diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py
index 2260ee2..2f2508d 100644
--- a/text_recognizer/networks/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/efficientnet/efficientnet.py
@@ -31,11 +31,11 @@ class EfficientNet(nn.Module):
def __init__(
self,
arch: str,
- params: Tuple[float, float, float],
stochastic_dropout_rate: float = 0.2,
bn_momentum: float = 0.99,
bn_eps: float = 1.0e-3,
depth: int = 7,
+ out_channels: int = 1280,
) -> None:
super().__init__()
self.params = self._get_arch_params(arch)
@@ -43,7 +43,7 @@ class EfficientNet(nn.Module):
self.bn_momentum = bn_momentum
self.bn_eps = bn_eps
self.depth = depth
- self.out_channels: int
+ self.out_channels: int = out_channels
self._conv_stem: nn.Sequential
self._blocks: nn.ModuleList
self._conv_head: nn.Sequential
@@ -92,7 +92,6 @@ class EfficientNet(nn.Module):
args.stride = 1
in_channels = round_filters(_block_args[-1].out_channels, self.params)
- self.out_channels = round_filters(1280, self.params)
self._conv_head = nn.Sequential(
nn.Conv2d(
in_channels, self.out_channels, kernel_size=1, stride=1, bias=False