summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/encoders/efficientnet
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-06-26 00:35:02 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-06-26 00:35:02 +0200
commit22e36513dd43d2e2ca82ca28a1ea757c5663676a (patch)
tree54285c3c30a02b00af989078bf61c122b9eccabd /text_recognizer/networks/encoders/efficientnet
parent9c3a8753d95ecb70a84e1eb40933590a510abfc4 (diff)
Updates
Diffstat (limited to 'text_recognizer/networks/encoders/efficientnet')
-rw-r--r--text_recognizer/networks/encoders/efficientnet/efficientnet.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
index 283b686..a59abf8 100644
--- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
@@ -27,6 +27,7 @@ class EfficientNet(nn.Module):
def __init__(
self,
arch: str,
+ out_channels: int = 256,
stochastic_dropout_rate: float = 0.2,
bn_momentum: float = 0.99,
bn_eps: float = 1.0e-3,
@@ -34,6 +35,7 @@ class EfficientNet(nn.Module):
super().__init__()
assert arch in self.archs, f"{arch} not a valid efficient net architecure!"
self.arch = self.archs[arch]
+ self.out_channels = out_channels
self.stochastic_dropout_rate = stochastic_dropout_rate
self.bn_momentum = 1 - bn_momentum
self.bn_eps = bn_eps
@@ -77,7 +79,7 @@ class EfficientNet(nn.Module):
args.stride = 1
in_channels = round_filters(320, self.arch)
- out_channels = round_filters(1280, self.arch)
+ out_channels = round_filters(self.out_channels, self.arch)
self._conv_head = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
nn.BatchNorm2d(