summary refs log tree commit diff
path: root/src/text_recognizer/networks/wide_resnet.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/networks/wide_resnet.py')
-rw-r--r--  src/text_recognizer/networks/wide_resnet.py  13
1 files changed, 10 insertions, 3 deletions
diff --git a/src/text_recognizer/networks/wide_resnet.py b/src/text_recognizer/networks/wide_resnet.py
index 28f3380..b767778 100644
--- a/src/text_recognizer/networks/wide_resnet.py
+++ b/src/text_recognizer/networks/wide_resnet.py
@@ -113,6 +113,7 @@ class WideResidualNetwork(nn.Module):
dropout_rate: float = 0.0,
num_layers: int = 3,
block: Type[nn.Module] = WideBlock,
+ num_stages: Optional[List[int]] = None,
activation: str = "relu",
use_decoder: bool = True,
) -> None:
@@ -127,6 +128,7 @@ class WideResidualNetwork(nn.Module):
dropout_rate (float): The dropout rate. Defaults to 0.0.
num_layers (int): Number of layers of blocks. Defaults to 3.
block (Type[nn.Module]): The default block is WideBlock. Defaults to WideBlock.
+ num_stages (List[int]): If given, will use these channel values. Defaults to None.
activation (str): Name of the activation to use. Defaults to "relu".
use_decoder (bool): If True, the network output character predictions, if False, the network outputs a
latent vector. Defaults to True.
@@ -149,9 +151,14 @@ class WideResidualNetwork(nn.Module):
self.dropout_rate = dropout_rate
self.activation = activation_function(activation)
- self.num_stages = [self.in_planes] + [
- self.in_planes * 2 ** n * self.width_factor for n in range(self.num_layers)
- ]
+ if num_stages is None:
+ self.num_stages = [self.in_planes] + [
+ self.in_planes * 2 ** n * self.width_factor
+ for n in range(self.num_layers)
+ ]
+ else:
+ self.num_stages = [self.in_planes] + num_stages
+
self.num_stages = list(zip(self.num_stages, self.num_stages[1:]))
self.strides = [1] + [2] * (self.num_layers - 1)