diff options
Diffstat (limited to 'src/text_recognizer/networks/residual_network.py')
-rw-r--r-- | src/text_recognizer/networks/residual_network.py | 35 |
1 files changed, 17 insertions, 18 deletions
diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py index 47e351a..1b5d6b3 100644 --- a/src/text_recognizer/networks/residual_network.py +++ b/src/text_recognizer/networks/residual_network.py @@ -8,6 +8,7 @@ from torch import nn from torch import Tensor from text_recognizer.networks.misc import activation_function +from text_recognizer.networks.stn import SpatialTransformerNetwork class Conv2dAuto(nn.Conv2d): @@ -197,25 +198,28 @@ class ResidualLayer(nn.Module): return x -class Encoder(nn.Module): +class ResidualNetworkEncoder(nn.Module): """Encoder network.""" def __init__( self, in_channels: int = 1, - block_sizes: List[int] = (32, 64), - depths: List[int] = (2, 2), + block_sizes: Union[int, List[int]] = (32, 64), + depths: Union[int, List[int]] = (2, 2), activation: str = "relu", block: Type[nn.Module] = BasicBlock, + levels: int = 1, + stn: bool = False, *args, **kwargs ) -> None: super().__init__() - - self.block_sizes = block_sizes - self.depths = depths + self.stn = SpatialTransformerNetwork() if stn else None + self.block_sizes = ( + block_sizes if isinstance(block_sizes, list) else [block_sizes] * levels + ) + self.depths = depths if isinstance(depths, list) else [depths] * levels self.activation = activation - self.gate = nn.Sequential( nn.Conv2d( in_channels=in_channels, @@ -227,7 +231,7 @@ class Encoder(nn.Module): ), nn.BatchNorm2d(self.block_sizes[0]), activation_function(self.activation), - nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + nn.MaxPool2d(kernel_size=2, stride=2, padding=1), ) self.blocks = self._configure_blocks(block) @@ -271,11 +275,13 @@ class Encoder(nn.Module): # If batch dimenstion is missing, it needs to be added. if len(x.shape) == 3: x = x.unsqueeze(0) + if self.stn is not None: + x = self.stn(x) x = self.gate(x) return self.blocks(x) -class Decoder(nn.Module): +class ResidualNetworkDecoder(nn.Module): """Classification head.""" def __init__(self, in_features: int, num_classes: int = 80) -> None: @@ -295,19 +301,12 @@ class ResidualNetwork(nn.Module): def __init__(self, in_channels: int, num_classes: int, *args, **kwargs) -> None: super().__init__() - self.encoder = Encoder(in_channels, *args, **kwargs) - self.decoder = Decoder( + self.encoder = ResidualNetworkEncoder(in_channels, *args, **kwargs) + self.decoder = ResidualNetworkDecoder( in_features=self.encoder.blocks[-1].blocks[-1].expanded_channels, num_classes=num_classes, ) - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") - elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - def forward(self, x: Tensor) -> Tensor: """Forward pass.""" x = self.encoder(x) |