From 22e36513dd43d2e2ca82ca28a1ea757c5663676a Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Sat, 26 Jun 2021 00:35:02 +0200
Subject: Updates

---
 text_recognizer/networks/encoders/efficientnet/efficientnet.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

(limited to 'text_recognizer/networks')

diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
index 283b686..a59abf8 100644
--- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
@@ -27,6 +27,7 @@ class EfficientNet(nn.Module):
     def __init__(
         self,
         arch: str,
+        out_channels: int = 256,
         stochastic_dropout_rate: float = 0.2,
         bn_momentum: float = 0.99,
         bn_eps: float = 1.0e-3,
@@ -34,6 +35,7 @@ class EfficientNet(nn.Module):
         super().__init__()
         assert arch in self.archs, f"{arch} not a valid efficient net architecure!"
         self.arch = self.archs[arch]
+        self.out_channels = out_channels
         self.stochastic_dropout_rate = stochastic_dropout_rate
         self.bn_momentum = 1 - bn_momentum
         self.bn_eps = bn_eps
@@ -77,7 +79,7 @@ class EfficientNet(nn.Module):
                 args.stride = 1
 
         in_channels = round_filters(320, self.arch)
-        out_channels = round_filters(1280, self.arch)
+        out_channels = round_filters(self.out_channels, self.arch)
         self._conv_head = nn.Sequential(
             nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
             nn.BatchNorm2d(
-- 
cgit v1.2.3-70-g09d2