summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-09-14 22:15:47 +0200
committeraktersnurra <gustaf.rydholm@gmail.com>2020-09-14 22:15:47 +0200
commit3b06ef615a8db67a03927576e0c12fbfb2501f5f (patch)
treee1c2b1289971c8480327408de46152481e99b539 /src/text_recognizer/networks
parent2b63fd952bdc9c7c72edd501cbcdbf3231e98f00 (diff)
Fixed CTC loss.
Diffstat (limited to 'src/text_recognizer/networks')
-rw-r--r--src/text_recognizer/networks/ctc.py2
-rw-r--r--src/text_recognizer/networks/line_lstm_ctc.py6
-rw-r--r--src/text_recognizer/networks/misc.py1
-rw-r--r--src/text_recognizer/networks/transformer.py1
-rw-r--r--src/text_recognizer/networks/wide_resnet.py8
5 files changed, 12 insertions, 6 deletions
diff --git a/src/text_recognizer/networks/ctc.py b/src/text_recognizer/networks/ctc.py
index 72f18b8..2493d5c 100644
--- a/src/text_recognizer/networks/ctc.py
+++ b/src/text_recognizer/networks/ctc.py
@@ -24,7 +24,7 @@ def greedy_decoder(
target_lengths (Optional[Tensor]): Length of each target tensor. Defaults to None.
character_mapper (Optional[Callable]): A emnist/character mapper for mapping integers to characters. Defaults
to None.
- blank_label (int): The blank character to be ignored. Defaults to 79.
+ blank_label (int): The blank character to be ignored. Defaults to 80.
collapse_repeated (bool): Collapase consecutive predictions of the same character. Defaults to True.
Returns:
diff --git a/src/text_recognizer/networks/line_lstm_ctc.py b/src/text_recognizer/networks/line_lstm_ctc.py
index 988b615..5c57479 100644
--- a/src/text_recognizer/networks/line_lstm_ctc.py
+++ b/src/text_recognizer/networks/line_lstm_ctc.py
@@ -33,8 +33,9 @@ class LineRecurrentNetwork(nn.Module):
self.hidden_size = hidden_size
self.encoder = self._configure_encoder(encoder)
self.flatten = flatten
+ self.fc = nn.Linear(in_features=self.input_size, out_features=self.hidden_size)
self.rnn = nn.LSTM(
- input_size=self.input_size,
+ input_size=self.hidden_size,
hidden_size=self.hidden_size,
num_layers=num_layers,
)
@@ -73,6 +74,9 @@ class LineRecurrentNetwork(nn.Module):
# Avgerage pooling.
x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t) if self.flatten else x
+ # Linear layer between CNN and RNN
+ x = self.fc(x)
+
# Sequence predictions.
x, _ = self.rnn(x)
diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/misc.py
index cac9e78..1f853e9 100644
--- a/src/text_recognizer/networks/misc.py
+++ b/src/text_recognizer/networks/misc.py
@@ -34,6 +34,7 @@ def activation_function(activation: str) -> Type[nn.Module]:
"""Returns the callable activation function."""
activation_fns = nn.ModuleDict(
[
+ ["elu", nn.ELU(inplace=True)],
["gelu", nn.GELU()],
["leaky_relu", nn.LeakyReLU(negative_slope=1.0e-2, inplace=True)],
["none", nn.Identity()],
diff --git a/src/text_recognizer/networks/transformer.py b/src/text_recognizer/networks/transformer.py
new file mode 100644
index 0000000..868d739
--- /dev/null
+++ b/src/text_recognizer/networks/transformer.py
@@ -0,0 +1 @@
+"""TBC."""
diff --git a/src/text_recognizer/networks/wide_resnet.py b/src/text_recognizer/networks/wide_resnet.py
index d1c8f9a..618f414 100644
--- a/src/text_recognizer/networks/wide_resnet.py
+++ b/src/text_recognizer/networks/wide_resnet.py
@@ -28,10 +28,10 @@ def conv_init(module: Type[nn.Module]) -> None:
classname = module.__class__.__name__
if classname.find("Conv") != -1:
nn.init.xavier_uniform_(module.weight, gain=np.sqrt(2))
- nn.init.constant(module.bias, 0)
+ nn.init.constant_(module.bias, 0)
elif classname.find("BatchNorm") != -1:
- nn.init.constant(module.weight, 1)
- nn.init.constant(module.bias, 0)
+ nn.init.constant_(module.weight, 1)
+ nn.init.constant_(module.bias, 0)
class WideBlock(nn.Module):
@@ -183,7 +183,7 @@ class WideResidualNetwork(nn.Module):
else None
)
- self.apply(conv_init)
+ # self.apply(conv_init)
def _configure_wide_layer(
self, in_planes: int, out_planes: int, stride: int, activation: str