diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-11-16 20:26:32 +0100 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-11-16 20:26:32 +0100 |
commit | f2cd16f340aa11afadb8fa90c29f85ca1b75a600 (patch) | |
tree | 8a7b32a9b2662b3b3bcec1119a9e6a25bb599cb9 /src/text_recognizer/networks/crnn.py | |
parent | 75909723fa2b1f6245d5c5422e4f2e88b8a26052 (diff) |
Added a whitening transform.
Diffstat (limited to 'src/text_recognizer/networks/crnn.py')
-rw-r--r-- | src/text_recognizer/networks/crnn.py | 12 |
1 files changed, 7 insertions, 5 deletions
diff --git a/src/text_recognizer/networks/crnn.py b/src/text_recognizer/networks/crnn.py index 9747429..778e232 100644 --- a/src/text_recognizer/networks/crnn.py +++ b/src/text_recognizer/networks/crnn.py @@ -1,4 +1,4 @@ -"""LSTM with CTC for handwritten text recognition within a line.""" +"""CRNN for handwritten text recognition.""" from typing import Dict, Tuple from einops import rearrange, reduce @@ -89,20 +89,22 @@ class ConvolutionalRecurrentNetwork(nn.Module): x = self.backbone(x) - # Avgerage pooling. + # Average pooling. if self.avg_pool: x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t) else: x = rearrange(x, "(b t) h -> t b h", b=b, t=t) else: # Encode the entire image with a CNN, and use the channels as temporal dimension. - b = x.shape[0] x = self.backbone(x) - x = rearrange(x, "b c h w -> c b (h w)", b=b) + x = rearrange(x, "b c h w -> b w c h") + if self.adaptive_pool is not None: + x = self.adaptive_pool(x) + x = x.squeeze(3) # Sequence predictions. x, _ = self.rnn(x) - # Sequence to classifcation layer. + # Sequence to classification layer. x = self.decoder(x) return x |