summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/efficientnet/efficientnet.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-06-10 00:31:48 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-06-10 00:31:48 +0200
commit14760428dd457f3749c6513ad34b822b05d6a742 (patch)
treee985f7e310e92043306d14d21d3f2d7bc8930772 /text_recognizer/networks/efficientnet/efficientnet.py
parentb8372664412622cc0c35eeaec0ebce6cf3c0f03c (diff)
Fix efficientnet
Diffstat (limited to 'text_recognizer/networks/efficientnet/efficientnet.py')
-rw-r--r--text_recognizer/networks/efficientnet/efficientnet.py5
1 files changed, 2 insertions, 3 deletions
diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py
index 2260ee2..2f2508d 100644
--- a/text_recognizer/networks/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/efficientnet/efficientnet.py
@@ -31,11 +31,11 @@ class EfficientNet(nn.Module):
def __init__(
self,
arch: str,
- params: Tuple[float, float, float],
stochastic_dropout_rate: float = 0.2,
bn_momentum: float = 0.99,
bn_eps: float = 1.0e-3,
depth: int = 7,
+ out_channels: int = 1280,
) -> None:
super().__init__()
self.params = self._get_arch_params(arch)
@@ -43,7 +43,7 @@ class EfficientNet(nn.Module):
self.bn_momentum = bn_momentum
self.bn_eps = bn_eps
self.depth = depth
- self.out_channels: int
+ self.out_channels: int = out_channels
self._conv_stem: nn.Sequential
self._blocks: nn.ModuleList
self._conv_head: nn.Sequential
@@ -92,7 +92,6 @@ class EfficientNet(nn.Module):
args.stride = 1
in_channels = round_filters(_block_args[-1].out_channels, self.params)
- self.out_channels = round_filters(1280, self.params)
self._conv_head = nn.Sequential(
nn.Conv2d(
in_channels, self.out_channels, kernel_size=1, stride=1, bias=False