diff options
author | aktersnurra <grydholm@kth.se> | 2020-12-02 23:48:52 +0100 |
---|---|---|
committer | aktersnurra <grydholm@kth.se> | 2020-12-02 23:48:52 +0100 |
commit | 5529e0fc9ca39e81fe0f08a54f257d32f0afe120 (patch) | |
tree | f2be992554e278857db7d56786dba54a76d439c7 /src/text_recognizer/networks | |
parent | e3b039c9adb4bce42ede4cb682a3ae71e797539a (diff) | |
parent | 8e3985c9cde6666e4314973312135ec1c7a025b9 (diff) |
Merge branch 'master' of github.com:aktersnurra/text-recognizer
Diffstat (limited to 'src/text_recognizer/networks')
-rw-r--r-- | src/text_recognizer/networks/__init__.py | 5 | ||||
-rw-r--r-- | src/text_recognizer/networks/crnn.py | 12 | ||||
-rw-r--r-- | src/text_recognizer/networks/metrics.py | 107 |
3 files changed, 119 insertions, 5 deletions
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py index 67e245c..1635039 100644 --- a/src/text_recognizer/networks/__init__.py +++ b/src/text_recognizer/networks/__init__.py @@ -4,6 +4,7 @@ from .crnn import ConvolutionalRecurrentNetwork from .ctc import greedy_decoder from .densenet import DenseNet from .lenet import LeNet +from .metrics import accuracy, accuracy_ignore_pad, cer, wer from .mlp import MLP from .residual_network import ResidualNetwork, ResidualNetworkEncoder from .transformer import Transformer @@ -11,6 +12,9 @@ from .util import sliding_window from .wide_resnet import WideResidualNetwork __all__ = [ + "accuracy", + "accuracy_ignore_pad", + "cer", "CNNTransformer", "ConvolutionalRecurrentNetwork", "DenseNet", @@ -21,5 +25,6 @@ __all__ = [ "ResidualNetworkEncoder", "sliding_window", "Transformer", + "wer", "WideResidualNetwork", ] diff --git a/src/text_recognizer/networks/crnn.py b/src/text_recognizer/networks/crnn.py index 9747429..778e232 100644 --- a/src/text_recognizer/networks/crnn.py +++ b/src/text_recognizer/networks/crnn.py @@ -1,4 +1,4 @@ -"""LSTM with CTC for handwritten text recognition within a line.""" +"""CRNN for handwritten text recognition.""" from typing import Dict, Tuple from einops import rearrange, reduce @@ -89,20 +89,22 @@ class ConvolutionalRecurrentNetwork(nn.Module): x = self.backbone(x) - # Avgerage pooling. + # Average pooling. if self.avg_pool: x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t) else: x = rearrange(x, "(b t) h -> t b h", b=b, t=t) else: # Encode the entire image with a CNN, and use the channels as temporal dimension. - b = x.shape[0] x = self.backbone(x) - x = rearrange(x, "b c h w -> c b (h w)", b=b) + x = rearrange(x, "b c h w -> b w c h") + if self.adaptive_pool is not None: + x = self.adaptive_pool(x) + x = x.squeeze(3) # Sequence predictions. x, _ = self.rnn(x) - # Sequence to classifcation layer. + # Sequence to classification layer. 
"""Metrics (accuracy, CER, WER) for evaluating text recognition networks."""
from typing import List, Tuple

import Levenshtein as Lev
import torch
from torch import Tensor

from text_recognizer.networks import greedy_decoder


def accuracy_ignore_pad(
    output: Tensor,
    target: Tensor,
    pad_index: int = 79,
    eos_index: int = 81,
    seq_len: int = 97,
) -> float:
    """Computes accuracy with every prediction after an EOS token set to pad.

    For each EOS token found in the (flattened) target, all predicted
    positions from just after that EOS up to the end of the corresponding
    fixed-length sequence are overwritten with the pad index, so trailing
    positions compare equal to padded targets.

    Args:
        output (Tensor): Network predictions, flattened over the batch.
        target (Tensor): Ground truth labels, flattened over the batch.
        pad_index (int): Index of the pad token. Defaults to 79.
        eos_index (int): Index of the end-of-sequence token. Defaults to 81.
        seq_len (int): Fixed length of each sequence. Defaults to 97.

    Returns:
        float: The accuracy for the batch.

    """
    # Clone so the caller's tensor is not silently mutated in place.
    output = output.clone()
    start_indices = torch.nonzero(target == eos_index, as_tuple=False).squeeze(1)
    end_indices = torch.arange(seq_len, target.shape[0] + 1, seq_len)
    for start, stop in zip(start_indices, end_indices):
        output[start + 1 : stop] = pad_index

    return accuracy(output, target)


def accuracy(outputs: Tensor, labels: Tensor) -> float:
    """Computes the accuracy.

    Args:
        outputs (Tensor): The output from the network.
        labels (Tensor): Ground truth labels.

    Returns:
        float: The accuracy for the batch.

    """
    _, predicted = torch.max(outputs, dim=-1)

    # numel() equals shape[0] for the expected flattened 1-D labels, and keeps
    # the denominator correct should labels arrive with more dimensions.
    acc = (predicted == labels).sum().float() / labels.numel()
    return acc.item()


def _greedy_decode(outputs: Tensor, targets: Tensor) -> Tuple[List, List]:
    """Greedy-decodes network outputs and targets into label sequences.

    Shared by cer() and wer(). Assumes every target sequence in the batch has
    the same (padded) length, taken from targets.shape[1]; the batch size is
    taken from outputs.shape[1] — presumably outputs is (seq, batch, classes),
    TODO confirm against greedy_decoder.
    """
    target_lengths = torch.full(
        size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
    )
    return greedy_decoder(outputs, targets, target_lengths)


def cer(outputs: Tensor, targets: Tensor) -> float:
    """Computes the character error rate.

    Args:
        outputs (Tensor): The output from the network.
        targets (Tensor): Ground truth labels.

    Returns:
        float: The cer for the batch.

    """
    decoded_predictions, decoded_targets = _greedy_decode(outputs, targets)

    lev_dist = 0
    for prediction, target in zip(decoded_predictions, decoded_targets):
        # Spaces are stripped: only character identity is compared.
        prediction = "".join(prediction).replace(" ", "")
        target = "".join(target).replace(" ", "")
        lev_dist += Lev.distance(prediction, target)

    return lev_dist / len(decoded_predictions)


def wer(outputs: Tensor, targets: Tensor) -> float:
    """Computes the word error rate.

    Args:
        outputs (Tensor): The output from the network.
        targets (Tensor): Ground truth labels.

    Returns:
        float: The wer for the batch.

    """
    decoded_predictions, decoded_targets = _greedy_decode(outputs, targets)

    lev_dist = 0
    for prediction, target in zip(decoded_predictions, decoded_targets):
        prediction = "".join(prediction)
        target = "".join(target)

        # Map each distinct word to a unique single character so that the
        # character-level Levenshtein distance counts word-level edits.
        vocabulary = set(prediction.split() + target.split())
        word2char = dict(zip(vocabulary, range(len(vocabulary))))

        w1 = [chr(word2char[w]) for w in prediction.split()]
        w2 = [chr(word2char[w]) for w in target.split()]

        lev_dist += Lev.distance("".join(w1), "".join(w2))

    return lev_dist / len(decoded_predictions)