diff options
| author | aktersnurra <grydholm@kth.se> | 2020-12-02 23:48:52 +0100 | 
|---|---|---|
| committer | aktersnurra <grydholm@kth.se> | 2020-12-02 23:48:52 +0100 | 
| commit | 5529e0fc9ca39e81fe0f08a54f257d32f0afe120 (patch) | |
| tree | f2be992554e278857db7d56786dba54a76d439c7 /src/text_recognizer/networks | |
| parent | e3b039c9adb4bce42ede4cb682a3ae71e797539a (diff) | |
| parent | 8e3985c9cde6666e4314973312135ec1c7a025b9 (diff) | |
Merge branch 'master' of github.com:aktersnurra/text-recognizer
Diffstat (limited to 'src/text_recognizer/networks')
| -rw-r--r-- | src/text_recognizer/networks/__init__.py | 5 | ||||
| -rw-r--r-- | src/text_recognizer/networks/crnn.py | 12 | ||||
| -rw-r--r-- | src/text_recognizer/networks/metrics.py | 107 | 
3 files changed, 119 insertions, 5 deletions
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py index 67e245c..1635039 100644 --- a/src/text_recognizer/networks/__init__.py +++ b/src/text_recognizer/networks/__init__.py @@ -4,6 +4,7 @@ from .crnn import ConvolutionalRecurrentNetwork  from .ctc import greedy_decoder  from .densenet import DenseNet  from .lenet import LeNet +from .metrics import accuracy, accuracy_ignore_pad, cer, wer  from .mlp import MLP  from .residual_network import ResidualNetwork, ResidualNetworkEncoder  from .transformer import Transformer @@ -11,6 +12,9 @@ from .util import sliding_window  from .wide_resnet import WideResidualNetwork  __all__ = [ +    "accuracy", +    "accuracy_ignore_pad", +    "cer",      "CNNTransformer",      "ConvolutionalRecurrentNetwork",      "DenseNet", @@ -21,5 +25,6 @@ __all__ = [      "ResidualNetworkEncoder",      "sliding_window",      "Transformer", +    "wer",      "WideResidualNetwork",  ] diff --git a/src/text_recognizer/networks/crnn.py b/src/text_recognizer/networks/crnn.py index 9747429..778e232 100644 --- a/src/text_recognizer/networks/crnn.py +++ b/src/text_recognizer/networks/crnn.py @@ -1,4 +1,4 @@ -"""LSTM with CTC for handwritten text recognition within a line.""" +"""CRNN for handwritten text recognition."""  from typing import Dict, Tuple  from einops import rearrange, reduce @@ -89,20 +89,22 @@ class ConvolutionalRecurrentNetwork(nn.Module):              x = self.backbone(x) -            # Avgerage pooling. +            # Average pooling.              if self.avg_pool:                  x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)              else:                  x = rearrange(x, "(b t) h -> t b h", b=b, t=t)          else:              # Encode the entire image with a CNN, and use the channels as temporal dimension. -            b = x.shape[0]              x = self.backbone(x) -            x = rearrange(x, "b c h w -> c b (h w)", b=b) +            x = rearrange(x, "b c h w -> b w c h") +            if self.adaptive_pool is not None: +                x = self.adaptive_pool(x) +            x = x.squeeze(3)          # Sequence predictions.          x, _ = self.rnn(x) -        # Sequence to classifcation layer. +        # Sequence to classification layer.          x = self.decoder(x)          return x diff --git a/src/text_recognizer/networks/metrics.py b/src/text_recognizer/networks/metrics.py new file mode 100644 index 0000000..af9adb5 --- /dev/null +++ b/src/text_recognizer/networks/metrics.py @@ -0,0 +1,107 @@ +"""Utility functions for models.""" +import Levenshtein as Lev +import torch +from torch import Tensor + +from text_recognizer.networks import greedy_decoder + + +def accuracy_ignore_pad( +    output: Tensor, +    target: Tensor, +    pad_index: int = 79, +    eos_index: int = 81, +    seq_len: int = 97, +) -> float: +    """Sets all predictions after eos to pad.""" +    start_indices = torch.nonzero(target == eos_index, as_tuple=False).squeeze(1) +    end_indices = torch.arange(seq_len, target.shape[0] + 1, seq_len) +    for start, stop in zip(start_indices, end_indices): +        output[start + 1 : stop] = pad_index + +    return accuracy(output, target) + + +def accuracy(outputs: Tensor, labels: Tensor,) -> float: +    """Computes the accuracy. + +    Args: +        outputs (Tensor): The output from the network. +        labels (Tensor): Ground truth labels. + +    Returns: +        float: The accuracy for the batch. + +    """ + +    _, predicted = torch.max(outputs, dim=-1) + +    acc = (predicted == labels).sum().float() / labels.shape[0] +    acc = acc.item() +    return acc + + +def cer(outputs: Tensor, targets: Tensor) -> float: +    """Computes the character error rate. + +    Args: +        outputs (Tensor): The output from the network. +        targets (Tensor): Ground truth labels. + +    Returns: +        float: The cer for the batch. + +    """ +    target_lengths = torch.full( +        size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long, +    ) +    decoded_predictions, decoded_targets = greedy_decoder( +        outputs, targets, target_lengths +    ) + +    lev_dist = 0 + +    for prediction, target in zip(decoded_predictions, decoded_targets): +        prediction = "".join(prediction) +        target = "".join(target) +        prediction, target = ( +            prediction.replace(" ", ""), +            target.replace(" ", ""), +        ) +        lev_dist += Lev.distance(prediction, target) +    return lev_dist / len(decoded_predictions) + + +def wer(outputs: Tensor, targets: Tensor) -> float: +    """Computes the Word error rate. + +    Args: +        outputs (Tensor): The output from the network. +        targets (Tensor): Ground truth labels. + +    Returns: +        float: The wer for the batch. + +    """ +    target_lengths = torch.full( +        size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long, +    ) +    decoded_predictions, decoded_targets = greedy_decoder( +        outputs, targets, target_lengths +    ) + +    lev_dist = 0 + +    for prediction, target in zip(decoded_predictions, decoded_targets): +        prediction = "".join(prediction) +        target = "".join(target) + +        b = set(prediction.split() + target.split()) +        word2char = dict(zip(b, range(len(b)))) + +        w1 = [chr(word2char[w]) for w in prediction.split()] +        w2 = [chr(word2char[w]) for w in target.split()] + +        lev_dist += Lev.distance("".join(w1), "".join(w2)) + +    return lev_dist / len(decoded_predictions)  |