Diffstat (limited to 'src/text_recognizer/networks/residual_network.py')
-rw-r--r--  src/text_recognizer/networks/residual_network.py  |  35
1 file changed, 17 insertions(+), 18 deletions(-)
diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py
index 47e351a..1b5d6b3 100644
--- a/src/text_recognizer/networks/residual_network.py
+++ b/src/text_recognizer/networks/residual_network.py
@@ -8,6 +8,7 @@ from torch import nn
from torch import Tensor
from text_recognizer.networks.misc import activation_function
+from text_recognizer.networks.stn import SpatialTransformerNetwork
class Conv2dAuto(nn.Conv2d):
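For reference, the body of Conv2dAuto lies outside this hunk; the conventional pattern behind a class like this is to derive "same" padding from the kernel size. A minimal sketch of that pattern, not necessarily this file's exact body:

import torch
from torch import nn

class Conv2dAuto(nn.Conv2d):
    """Conv2d that infers its padding from the kernel size (sketch only)."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Half the kernel size in each spatial dimension keeps H and W
        # unchanged for stride-1 convolutions.
        self.padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2)

# Quick check: a 3x3 auto-padded conv preserves spatial size at stride 1.
conv = Conv2dAuto(in_channels=1, out_channels=8, kernel_size=3)
assert conv(torch.randn(1, 1, 28, 28)).shape == (1, 8, 28, 28)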
@@ -197,25 +198,28 @@ class ResidualLayer(nn.Module):
return x
-class Encoder(nn.Module):
+class ResidualNetworkEncoder(nn.Module):
"""Encoder network."""
def __init__(
self,
in_channels: int = 1,
- block_sizes: List[int] = (32, 64),
- depths: List[int] = (2, 2),
+ block_sizes: Union[int, List[int]] = (32, 64),
+ depths: Union[int, List[int]] = (2, 2),
activation: str = "relu",
block: Type[nn.Module] = BasicBlock,
+ levels: int = 1,
+ stn: bool = False,
*args,
**kwargs
) -> None:
super().__init__()
-
- self.block_sizes = block_sizes
- self.depths = depths
+ self.stn = SpatialTransformerNetwork() if stn else None
+ self.block_sizes = (
+ block_sizes if isinstance(block_sizes, list) else [block_sizes] * levels
+ )
+ self.depths = depths if isinstance(depths, list) else [depths] * levels
self.activation = activation
-
self.gate = nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
@@ -227,7 +231,7 @@ class Encoder(nn.Module):
),
nn.BatchNorm2d(self.block_sizes[0]),
activation_function(self.activation),
- nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
+ nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
)
self.blocks = self._configure_blocks(block)
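The widened Union[int, List[int]] signatures let a caller pass a single block size or depth that is broadcast across levels. One caveat worth noting: the defaults are tuples, which fall through the isinstance(..., list) check, so explicit lists (or plain ints) are the safe inputs. A hedged illustration of the resulting attributes, with made-up parameter values:

encoder = ResidualNetworkEncoder(in_channels=1, block_sizes=64, depths=2, levels=3)
assert encoder.block_sizes == [64, 64, 64]  # scalar broadcast over levels
assert encoder.depths == [2, 2, 2]

# Explicit lists are used verbatim, as before this change.
encoder = ResidualNetworkEncoder(in_channels=1, block_sizes=[32, 64], depths=[2, 2])
assert encoder.block_sizes == [32, 64]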
@@ -271,11 +275,13 @@ class Encoder(nn.Module):
# If the batch dimension is missing, it needs to be added.
if len(x.shape) == 3:
x = x.unsqueeze(0)
+ if self.stn is not None:
+ x = self.stn(x)
x = self.gate(x)
return self.blocks(x)
-class Decoder(nn.Module):
+class ResidualNetworkDecoder(nn.Module):
"""Classification head."""
def __init__(self, in_features: int, num_classes: int = 80) -> None:
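The body of the renamed decoder is not part of this hunk; classification heads of this shape are typically global average pooling followed by a linear projection. A sketch under that assumption, not the file's confirmed code:

from torch import Tensor, nn

class ResidualNetworkDecoder(nn.Module):
    """Classification head: pool spatial dims, then project to class logits."""

    def __init__(self, in_features: int, num_classes: int = 80) -> None:
        super().__init__()
        self.decoder = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),  # (B, C, H, W) -> (B, C, 1, 1)
            nn.Flatten(),                  # (B, C, 1, 1) -> (B, C)
            nn.Linear(in_features=in_features, out_features=num_classes),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.decoder(x)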
@@ -295,19 +301,12 @@ class ResidualNetwork(nn.Module):
def __init__(self, in_channels: int, num_classes: int, *args, **kwargs) -> None:
super().__init__()
- self.encoder = Encoder(in_channels, *args, **kwargs)
- self.decoder = Decoder(
+ self.encoder = ResidualNetworkEncoder(in_channels, *args, **kwargs)
+ self.decoder = ResidualNetworkDecoder(
in_features=self.encoder.blocks[-1].blocks[-1].expanded_channels,
num_classes=num_classes,
)
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
- elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
- nn.init.constant_(m.weight, 1)
- nn.init.constant_(m.bias, 0)
-
def forward(self, x: Tensor) -> Tensor:
"""Forward pass."""
x = self.encoder(x)
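End to end, the renamed classes wire together as above; extra keyword arguments such as stn pass through ResidualNetwork to the encoder. A hedged usage sketch (the input size is an assumption):

import torch
from text_recognizer.networks.residual_network import ResidualNetwork

net = ResidualNetwork(in_channels=1, num_classes=80, stn=False)
x = torch.randn(4, 1, 28, 28)  # batch of grayscale character crops (shape assumed)
logits = net(x)
print(logits.shape)  # expected: torch.Size([4, 80])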