From 527bb98b191d82b308de1585047e06056258d08d Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Wed, 18 Nov 2020 20:56:19 +0100 Subject: Some minor changes. --- src/text_recognizer/datasets/transforms.py | 15 ++-- src/text_recognizer/line_predictor.py | 4 +- src/text_recognizer/models/__init__.py | 5 -- src/text_recognizer/models/metrics.py | 107 ----------------------------- src/text_recognizer/networks/__init__.py | 5 ++ src/text_recognizer/networks/metrics.py | 107 +++++++++++++++++++++++++++++ 6 files changed, 124 insertions(+), 119 deletions(-) delete mode 100644 src/text_recognizer/models/metrics.py create mode 100644 src/text_recognizer/networks/metrics.py (limited to 'src/text_recognizer') diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py index d1ca127..1ec23dc 100644 --- a/src/text_recognizer/datasets/transforms.py +++ b/src/text_recognizer/datasets/transforms.py @@ -4,7 +4,7 @@ from PIL import Image import torch from torch import Tensor import torch.nn.functional as F -from torchvision.transforms import Compose, ToPILImage, ToTensor +from torchvision.transforms import Compose, RandomAffine, ToTensor from text_recognizer.datasets.util import EmnistMapper @@ -66,9 +66,14 @@ class AddTokens: return target -class Whitening: - """Whitening of Tensor, i.e. set mean to zero and std to one.""" +class ApplyContrast: + """Sets everything below a threshold to zero, i.e. increase contrast.""" + + def __init__(self, low: float = 0.0, high: float = 0.25) -> None: + self.low = low + self.high = high def __call__(self, x: Tensor) -> Tensor: - """Apply the whitening.""" - return (x - x.mean()) / x.std() + """Apply mask binary mask to input tensor.""" + mask = x > np.random.RandomState().uniform(low=self.low, high=self.high) + return x * mask diff --git a/src/text_recognizer/line_predictor.py b/src/text_recognizer/line_predictor.py index 981e2c9..8e348fe 100644 --- a/src/text_recognizer/line_predictor.py +++ b/src/text_recognizer/line_predictor.py @@ -6,7 +6,7 @@ import numpy as np from torch import nn from text_recognizer import datasets, networks -from text_recognizer.models import VisionTransformerModel +from text_recognizer.models import TransformerModel from text_recognizer.util import read_image @@ -16,7 +16,7 @@ class LinePredictor: def __init__(self, dataset: str, network_fn: str) -> None: network_fn = getattr(networks, network_fn) dataset = getattr(datasets, dataset) - self.model = VisionTransformerModel(network_fn=network_fn, dataset=dataset) + self.model = TransformerModel(network_fn=network_fn, dataset=dataset) self.model.eval() def predict(self, image_or_filename: Union[np.ndarray, str]) -> Tuple[str, float]: diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py index 53340f1..bf89404 100644 --- a/src/text_recognizer/models/__init__.py +++ b/src/text_recognizer/models/__init__.py @@ -2,16 +2,11 @@ from .base import Model from .character_model import CharacterModel from .crnn_model import CRNNModel -from .metrics import accuracy, accuracy_ignore_pad, cer, wer from .transformer_model import TransformerModel __all__ = [ - "accuracy", - "accuracy_ignore_pad", - "cer", "CharacterModel", "CRNNModel", "Model", "TransformerModel", - "wer", ] diff --git a/src/text_recognizer/models/metrics.py b/src/text_recognizer/models/metrics.py deleted file mode 100644 index af9adb5..0000000 --- a/src/text_recognizer/models/metrics.py +++ /dev/null @@ -1,107 +0,0 @@ -"""Utility functions for models.""" -import Levenshtein as Lev -import torch -from torch import Tensor - -from text_recognizer.networks import greedy_decoder - - -def accuracy_ignore_pad( - output: Tensor, - target: Tensor, - pad_index: int = 79, - eos_index: int = 81, - seq_len: int = 97, -) -> float: - """Sets all predictions after eos to pad.""" - start_indices = torch.nonzero(target == eos_index, as_tuple=False).squeeze(1) - end_indices = torch.arange(seq_len, target.shape[0] + 1, seq_len) - for start, stop in zip(start_indices, end_indices): - output[start + 1 : stop] = pad_index - - return accuracy(output, target) - - -def accuracy(outputs: Tensor, labels: Tensor,) -> float: - """Computes the accuracy. - - Args: - outputs (Tensor): The output from the network. - labels (Tensor): Ground truth labels. - - Returns: - float: The accuracy for the batch. - - """ - - _, predicted = torch.max(outputs, dim=-1) - - acc = (predicted == labels).sum().float() / labels.shape[0] - acc = acc.item() - return acc - - -def cer(outputs: Tensor, targets: Tensor) -> float: - """Computes the character error rate. - - Args: - outputs (Tensor): The output from the network. - targets (Tensor): Ground truth labels. - - Returns: - float: The cer for the batch. - - """ - target_lengths = torch.full( - size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long, - ) - decoded_predictions, decoded_targets = greedy_decoder( - outputs, targets, target_lengths - ) - - lev_dist = 0 - - for prediction, target in zip(decoded_predictions, decoded_targets): - prediction = "".join(prediction) - target = "".join(target) - prediction, target = ( - prediction.replace(" ", ""), - target.replace(" ", ""), - ) - lev_dist += Lev.distance(prediction, target) - return lev_dist / len(decoded_predictions) - - -def wer(outputs: Tensor, targets: Tensor) -> float: - """Computes the Word error rate. - - Args: - outputs (Tensor): The output from the network. - targets (Tensor): Ground truth labels. - - Returns: - float: The wer for the batch. - - """ - target_lengths = torch.full( - size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long, - ) - decoded_predictions, decoded_targets = greedy_decoder( - outputs, targets, target_lengths - ) - - lev_dist = 0 - - for prediction, target in zip(decoded_predictions, decoded_targets): - prediction = "".join(prediction) - target = "".join(target) - - b = set(prediction.split() + target.split()) - word2char = dict(zip(b, range(len(b)))) - - w1 = [chr(word2char[w]) for w in prediction.split()] - w2 = [chr(word2char[w]) for w in target.split()] - - lev_dist += Lev.distance("".join(w1), "".join(w2)) - - return lev_dist / len(decoded_predictions) diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py index 2cc1137..078d771 100644 --- a/src/text_recognizer/networks/__init__.py +++ b/src/text_recognizer/networks/__init__.py @@ -4,6 +4,7 @@ from .crnn import ConvolutionalRecurrentNetwork from .ctc import greedy_decoder from .densenet import DenseNet from .lenet import LeNet +from .metrics import accuracy, accuracy_ignore_pad, cer, wer from .mlp import MLP from .residual_network import ResidualNetwork, ResidualNetworkEncoder from .sparse_mlp import SparseMLP @@ -12,6 +13,9 @@ from .util import sliding_window from .wide_resnet import WideResidualNetwork __all__ = [ + "accuracy", + "accuracy_ignore_pad", + "cer", "CNNTransformer", "ConvolutionalRecurrentNetwork", "DenseNet", @@ -23,5 +27,6 @@ __all__ = [ "sliding_window", "Transformer", "SparseMLP", + "wer", "WideResidualNetwork", ] diff --git a/src/text_recognizer/networks/metrics.py b/src/text_recognizer/networks/metrics.py new file mode 100644 index 0000000..af9adb5 --- /dev/null +++ b/src/text_recognizer/networks/metrics.py @@ -0,0 +1,107 @@ +"""Utility functions for models.""" +import Levenshtein as Lev +import torch +from torch import Tensor + +from text_recognizer.networks import greedy_decoder + + +def accuracy_ignore_pad( + output: Tensor, + target: Tensor, + pad_index: int = 79, + eos_index: int = 81, + seq_len: int = 97, +) -> float: + """Sets all predictions after eos to pad.""" + start_indices = torch.nonzero(target == eos_index, as_tuple=False).squeeze(1) + end_indices = torch.arange(seq_len, target.shape[0] + 1, seq_len) + for start, stop in zip(start_indices, end_indices): + output[start + 1 : stop] = pad_index + + return accuracy(output, target) + + +def accuracy(outputs: Tensor, labels: Tensor,) -> float: + """Computes the accuracy. + + Args: + outputs (Tensor): The output from the network. + labels (Tensor): Ground truth labels. + + Returns: + float: The accuracy for the batch. + + """ + + _, predicted = torch.max(outputs, dim=-1) + + acc = (predicted == labels).sum().float() / labels.shape[0] + acc = acc.item() + return acc + + +def cer(outputs: Tensor, targets: Tensor) -> float: + """Computes the character error rate. + + Args: + outputs (Tensor): The output from the network. + targets (Tensor): Ground truth labels. + + Returns: + float: The cer for the batch. + + """ + target_lengths = torch.full( + size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long, + ) + decoded_predictions, decoded_targets = greedy_decoder( + outputs, targets, target_lengths + ) + + lev_dist = 0 + + for prediction, target in zip(decoded_predictions, decoded_targets): + prediction = "".join(prediction) + target = "".join(target) + prediction, target = ( + prediction.replace(" ", ""), + target.replace(" ", ""), + ) + lev_dist += Lev.distance(prediction, target) + return lev_dist / len(decoded_predictions) + + +def wer(outputs: Tensor, targets: Tensor) -> float: + """Computes the Word error rate. + + Args: + outputs (Tensor): The output from the network. + targets (Tensor): Ground truth labels. + + Returns: + float: The wer for the batch. + + """ + target_lengths = torch.full( + size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long, + ) + decoded_predictions, decoded_targets = greedy_decoder( + outputs, targets, target_lengths + ) + + lev_dist = 0 + + for prediction, target in zip(decoded_predictions, decoded_targets): + prediction = "".join(prediction) + target = "".join(target) + + b = set(prediction.split() + target.split()) + word2char = dict(zip(b, range(len(b)))) + + w1 = [chr(word2char[w]) for w in prediction.split()] + w2 = [chr(word2char[w]) for w in target.split()] + + lev_dist += Lev.distance("".join(w1), "".join(w2)) + + return lev_dist / len(decoded_predictions) -- cgit v1.2.3-70-g09d2