summaryrefslogtreecommitdiff
path: root/text_recognizer/networks
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--text_recognizer/networks/encoders/efficientnet/efficientnet.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
index 283b686..a59abf8 100644
--- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
@@ -27,6 +27,7 @@ class EfficientNet(nn.Module):
def __init__(
self,
arch: str,
+ out_channels: int = 256,
stochastic_dropout_rate: float = 0.2,
bn_momentum: float = 0.99,
bn_eps: float = 1.0e-3,
@@ -34,6 +35,7 @@ class EfficientNet(nn.Module):
super().__init__()
assert arch in self.archs, f"{arch} not a valid efficient net architecure!"
self.arch = self.archs[arch]
+ self.out_channels = out_channels
self.stochastic_dropout_rate = stochastic_dropout_rate
self.bn_momentum = 1 - bn_momentum
self.bn_eps = bn_eps
@@ -77,7 +79,7 @@ class EfficientNet(nn.Module):
args.stride = 1
in_channels = round_filters(320, self.arch)
- out_channels = round_filters(1280, self.arch)
+ out_channels = round_filters(self.out_channels, self.arch)
self._conv_head = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
nn.BatchNorm2d(