From 3b06ef615a8db67a03927576e0c12fbfb2501f5f Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Mon, 14 Sep 2020 22:15:47 +0200 Subject: Fixed CTC loss. --- src/text_recognizer/networks/line_lstm_ctc.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'src/text_recognizer/networks/line_lstm_ctc.py') diff --git a/src/text_recognizer/networks/line_lstm_ctc.py b/src/text_recognizer/networks/line_lstm_ctc.py index 988b615..5c57479 100644 --- a/src/text_recognizer/networks/line_lstm_ctc.py +++ b/src/text_recognizer/networks/line_lstm_ctc.py @@ -33,8 +33,9 @@ class LineRecurrentNetwork(nn.Module): self.hidden_size = hidden_size self.encoder = self._configure_encoder(encoder) self.flatten = flatten + self.fc = nn.Linear(in_features=self.input_size, out_features=self.hidden_size) self.rnn = nn.LSTM( - input_size=self.input_size, + input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=num_layers, ) @@ -73,6 +74,9 @@ class LineRecurrentNetwork(nn.Module): # Avgerage pooling. x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t) if self.flatten else x + # Linear layer between CNN and RNN + x = self.fc(x) + # Sequence predictions. x, _ = self.rnn(x) -- cgit v1.2.3-70-g09d2