diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-14 22:15:47 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-14 22:15:47 +0200 |
commit | 3b06ef615a8db67a03927576e0c12fbfb2501f5f (patch) | |
tree | e1c2b1289971c8480327408de46152481e99b539 /src/text_recognizer/networks | |
parent | 2b63fd952bdc9c7c72edd501cbcdbf3231e98f00 (diff) |
Fixed CTC loss.
Diffstat (limited to 'src/text_recognizer/networks')
-rw-r--r-- | src/text_recognizer/networks/ctc.py | 2 | ||||
-rw-r--r-- | src/text_recognizer/networks/line_lstm_ctc.py | 6 | ||||
-rw-r--r-- | src/text_recognizer/networks/misc.py | 1 | ||||
-rw-r--r-- | src/text_recognizer/networks/transformer.py | 1 | ||||
-rw-r--r-- | src/text_recognizer/networks/wide_resnet.py | 8 |
5 files changed, 12 insertions, 6 deletions
diff --git a/src/text_recognizer/networks/ctc.py b/src/text_recognizer/networks/ctc.py index 72f18b8..2493d5c 100644 --- a/src/text_recognizer/networks/ctc.py +++ b/src/text_recognizer/networks/ctc.py @@ -24,7 +24,7 @@ def greedy_decoder( target_lengths (Optional[Tensor]): Length of each target tensor. Defaults to None. character_mapper (Optional[Callable]): A emnist/character mapper for mapping integers to characters. Defaults to None. - blank_label (int): The blank character to be ignored. Defaults to 79. + blank_label (int): The blank character to be ignored. Defaults to 80. collapse_repeated (bool): Collapase consecutive predictions of the same character. Defaults to True. Returns: diff --git a/src/text_recognizer/networks/line_lstm_ctc.py b/src/text_recognizer/networks/line_lstm_ctc.py index 988b615..5c57479 100644 --- a/src/text_recognizer/networks/line_lstm_ctc.py +++ b/src/text_recognizer/networks/line_lstm_ctc.py @@ -33,8 +33,9 @@ class LineRecurrentNetwork(nn.Module): self.hidden_size = hidden_size self.encoder = self._configure_encoder(encoder) self.flatten = flatten + self.fc = nn.Linear(in_features=self.input_size, out_features=self.hidden_size) self.rnn = nn.LSTM( - input_size=self.input_size, + input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=num_layers, ) @@ -73,6 +74,9 @@ class LineRecurrentNetwork(nn.Module): # Avgerage pooling. x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t) if self.flatten else x + # Linear layer between CNN and RNN + x = self.fc(x) + # Sequence predictions. x, _ = self.rnn(x) diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/misc.py index cac9e78..1f853e9 100644 --- a/src/text_recognizer/networks/misc.py +++ b/src/text_recognizer/networks/misc.py @@ -34,6 +34,7 @@ def activation_function(activation: str) -> Type[nn.Module]: """Returns the callable activation function.""" activation_fns = nn.ModuleDict( [ + ["elu", nn.ELU(inplace=True)], ["gelu", nn.GELU()], ["leaky_relu", nn.LeakyReLU(negative_slope=1.0e-2, inplace=True)], ["none", nn.Identity()], diff --git a/src/text_recognizer/networks/transformer.py b/src/text_recognizer/networks/transformer.py new file mode 100644 index 0000000..868d739 --- /dev/null +++ b/src/text_recognizer/networks/transformer.py @@ -0,0 +1 @@ +"""TBC.""" diff --git a/src/text_recognizer/networks/wide_resnet.py b/src/text_recognizer/networks/wide_resnet.py index d1c8f9a..618f414 100644 --- a/src/text_recognizer/networks/wide_resnet.py +++ b/src/text_recognizer/networks/wide_resnet.py @@ -28,10 +28,10 @@ def conv_init(module: Type[nn.Module]) -> None: classname = module.__class__.__name__ if classname.find("Conv") != -1: nn.init.xavier_uniform_(module.weight, gain=np.sqrt(2)) - nn.init.constant(module.bias, 0) + nn.init.constant_(module.bias, 0) elif classname.find("BatchNorm") != -1: - nn.init.constant(module.weight, 1) - nn.init.constant(module.bias, 0) + nn.init.constant_(module.weight, 1) + nn.init.constant_(module.bias, 0) class WideBlock(nn.Module): @@ -183,7 +183,7 @@ class WideResidualNetwork(nn.Module): else None ) - self.apply(conv_init) + # self.apply(conv_init) def _configure_wide_layer( self, in_planes: int, out_planes: int, stride: int, activation: str |