summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/crnn.py
diff options
context:
space:
mode:
authoraktersnurra <grydholm@kth.se>2020-12-02 23:48:52 +0100
committeraktersnurra <grydholm@kth.se>2020-12-02 23:48:52 +0100
commit5529e0fc9ca39e81fe0f08a54f257d32f0afe120 (patch)
treef2be992554e278857db7d56786dba54a76d439c7 /src/text_recognizer/networks/crnn.py
parente3b039c9adb4bce42ede4cb682a3ae71e797539a (diff)
parent8e3985c9cde6666e4314973312135ec1c7a025b9 (diff)
Merge branch 'master' of github.com:aktersnurra/text-recognizer
Diffstat (limited to 'src/text_recognizer/networks/crnn.py')
-rw-r--r--src/text_recognizer/networks/crnn.py12
1 files changed, 7 insertions, 5 deletions
diff --git a/src/text_recognizer/networks/crnn.py b/src/text_recognizer/networks/crnn.py
index 9747429..778e232 100644
--- a/src/text_recognizer/networks/crnn.py
+++ b/src/text_recognizer/networks/crnn.py
@@ -1,4 +1,4 @@
-"""LSTM with CTC for handwritten text recognition within a line."""
+"""CRNN for handwritten text recognition."""
from typing import Dict, Tuple
from einops import rearrange, reduce
@@ -89,20 +89,22 @@ class ConvolutionalRecurrentNetwork(nn.Module):
x = self.backbone(x)
- # Avgerage pooling.
+ # Average pooling.
if self.avg_pool:
x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)
else:
x = rearrange(x, "(b t) h -> t b h", b=b, t=t)
else:
# Encode the entire image with a CNN, and use the channels as temporal dimension.
- b = x.shape[0]
x = self.backbone(x)
- x = rearrange(x, "b c h w -> c b (h w)", b=b)
+ x = rearrange(x, "b c h w -> b w c h")
+ if self.adaptive_pool is not None:
+ x = self.adaptive_pool(x)
+ x = x.squeeze(3)
# Sequence predictions.
x, _ = self.rnn(x)
- # Sequence to classifcation layer.
+ # Sequence to classification layer.
x = self.decoder(x)
return x