diff options
Diffstat (limited to 'src/text_recognizer/networks/wide_resnet.py')
-rw-r--r-- | src/text_recognizer/networks/wide_resnet.py | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/src/text_recognizer/networks/wide_resnet.py b/src/text_recognizer/networks/wide_resnet.py index 28f3380..b767778 100644 --- a/src/text_recognizer/networks/wide_resnet.py +++ b/src/text_recognizer/networks/wide_resnet.py @@ -113,6 +113,7 @@ class WideResidualNetwork(nn.Module): dropout_rate: float = 0.0, num_layers: int = 3, block: Type[nn.Module] = WideBlock, + num_stages: Optional[List[int]] = None, activation: str = "relu", use_decoder: bool = True, ) -> None: @@ -127,6 +128,7 @@ class WideResidualNetwork(nn.Module): dropout_rate (float): The dropout rate. Defaults to 0.0. num_layers (int): Number of layers of blocks. Defaults to 3. block (Type[nn.Module]): The default block is WideBlock. Defaults to WideBlock. + num_stages (List[int]): If given, will use these channel values. Defaults to None. activation (str): Name of the activation to use. Defaults to "relu". use_decoder (bool): If True, the network output character predictions, if False, the network outputs a latent vector. Defaults to True. @@ -149,9 +151,14 @@ class WideResidualNetwork(nn.Module): self.dropout_rate = dropout_rate self.activation = activation_function(activation) - self.num_stages = [self.in_planes] + [ - self.in_planes * 2 ** n * self.width_factor for n in range(self.num_layers) - ] + if num_stages is None: + self.num_stages = [self.in_planes] + [ + self.in_planes * 2 ** n * self.width_factor + for n in range(self.num_layers) + ] + else: + self.num_stages = [self.in_planes] + num_stages + self.num_stages = list(zip(self.num_stages, self.num_stages[1:])) self.strides = [1] + [2] * (self.num_layers - 1) |