summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/residual_network.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/networks/residual_network.py')
-rw-r--r--src/text_recognizer/networks/residual_network.py7
1 files changed, 1 insertions, 6 deletions
diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py
index 6405192..e397224 100644
--- a/src/text_recognizer/networks/residual_network.py
+++ b/src/text_recognizer/networks/residual_network.py
@@ -7,7 +7,6 @@ import torch
from torch import nn
from torch import Tensor
-from text_recognizer.networks.stn import SpatialTransformerNetwork
from text_recognizer.networks.util import activation_function
@@ -209,12 +208,10 @@ class ResidualNetworkEncoder(nn.Module):
activation: str = "relu",
block: Type[nn.Module] = BasicBlock,
levels: int = 1,
- stn: bool = False,
*args,
**kwargs
) -> None:
super().__init__()
- self.stn = SpatialTransformerNetwork() if stn else None
self.block_sizes = (
block_sizes if isinstance(block_sizes, list) else [block_sizes] * levels
)
@@ -231,7 +228,7 @@ class ResidualNetworkEncoder(nn.Module):
),
nn.BatchNorm2d(self.block_sizes[0]),
activation_function(self.activation),
- nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
+ # nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
)
self.blocks = self._configure_blocks(block)
@@ -275,8 +272,6 @@ class ResidualNetworkEncoder(nn.Module):
# If batch dimenstion is missing, it needs to be added.
if len(x.shape) == 3:
x = x.unsqueeze(0)
- if self.stn is not None:
- x = self.stn(x)
x = self.gate(x)
x = self.blocks(x)
return x