summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/wide_resnet.py
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-09-14 22:15:47 +0200
committeraktersnurra <gustaf.rydholm@gmail.com>2020-09-14 22:15:47 +0200
commit3b06ef615a8db67a03927576e0c12fbfb2501f5f (patch)
treee1c2b1289971c8480327408de46152481e99b539 /src/text_recognizer/networks/wide_resnet.py
parent2b63fd952bdc9c7c72edd501cbcdbf3231e98f00 (diff)
Fixed CTC loss.
Diffstat (limited to 'src/text_recognizer/networks/wide_resnet.py')
-rw-r--r--src/text_recognizer/networks/wide_resnet.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/src/text_recognizer/networks/wide_resnet.py b/src/text_recognizer/networks/wide_resnet.py
index d1c8f9a..618f414 100644
--- a/src/text_recognizer/networks/wide_resnet.py
+++ b/src/text_recognizer/networks/wide_resnet.py
@@ -28,10 +28,10 @@ def conv_init(module: Type[nn.Module]) -> None:
classname = module.__class__.__name__
if classname.find("Conv") != -1:
nn.init.xavier_uniform_(module.weight, gain=np.sqrt(2))
- nn.init.constant(module.bias, 0)
+ nn.init.constant_(module.bias, 0)
elif classname.find("BatchNorm") != -1:
- nn.init.constant(module.weight, 1)
- nn.init.constant(module.bias, 0)
+ nn.init.constant_(module.weight, 1)
+ nn.init.constant_(module.bias, 0)
class WideBlock(nn.Module):
@@ -183,7 +183,7 @@ class WideResidualNetwork(nn.Module):
else None
)
- self.apply(conv_init)
+ # self.apply(conv_init)
def _configure_wide_layer(
self, in_planes: int, out_planes: int, stride: int, activation: str