summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/line_lstm_ctc.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/networks/line_lstm_ctc.py')
-rw-r--r--src/text_recognizer/networks/line_lstm_ctc.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/src/text_recognizer/networks/line_lstm_ctc.py b/src/text_recognizer/networks/line_lstm_ctc.py
index 988b615..5c57479 100644
--- a/src/text_recognizer/networks/line_lstm_ctc.py
+++ b/src/text_recognizer/networks/line_lstm_ctc.py
@@ -33,8 +33,9 @@ class LineRecurrentNetwork(nn.Module):
self.hidden_size = hidden_size
self.encoder = self._configure_encoder(encoder)
self.flatten = flatten
+ self.fc = nn.Linear(in_features=self.input_size, out_features=self.hidden_size)
self.rnn = nn.LSTM(
- input_size=self.input_size,
+ input_size=self.hidden_size,
hidden_size=self.hidden_size,
num_layers=num_layers,
)
@@ -73,6 +74,9 @@ class LineRecurrentNetwork(nn.Module):
# Avgerage pooling.
x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t) if self.flatten else x
+ # Linear layer between CNN and RNN
+ x = self.fc(x)
+
# Sequence predictions.
x, _ = self.rnn(x)