diff options
Diffstat (limited to 'src/text_recognizer/networks/residual_network.py')
-rw-r--r-- | src/text_recognizer/networks/residual_network.py | 7 |
1 files changed, 1 insertions, 6 deletions
diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py index 6405192..e397224 100644 --- a/src/text_recognizer/networks/residual_network.py +++ b/src/text_recognizer/networks/residual_network.py @@ -7,7 +7,6 @@ import torch from torch import nn from torch import Tensor -from text_recognizer.networks.stn import SpatialTransformerNetwork from text_recognizer.networks.util import activation_function @@ -209,12 +208,10 @@ class ResidualNetworkEncoder(nn.Module): activation: str = "relu", block: Type[nn.Module] = BasicBlock, levels: int = 1, - stn: bool = False, *args, **kwargs ) -> None: super().__init__() - self.stn = SpatialTransformerNetwork() if stn else None self.block_sizes = ( block_sizes if isinstance(block_sizes, list) else [block_sizes] * levels ) @@ -231,7 +228,7 @@ class ResidualNetworkEncoder(nn.Module): ), nn.BatchNorm2d(self.block_sizes[0]), activation_function(self.activation), - nn.MaxPool2d(kernel_size=2, stride=2, padding=1), + # nn.MaxPool2d(kernel_size=2, stride=2, padding=1), ) self.blocks = self._configure_blocks(block) @@ -275,8 +272,6 @@ class ResidualNetworkEncoder(nn.Module): # If batch dimenstion is missing, it needs to be added. if len(x.shape) == 3: x = x.unsqueeze(0) - if self.stn is not None: - x = self.stn(x) x = self.gate(x) x = self.blocks(x) return x |