summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/networks')
-rw-r--r--src/text_recognizer/networks/__init__.py5
-rw-r--r--src/text_recognizer/networks/crnn.py12
-rw-r--r--src/text_recognizer/networks/metrics.py107
3 files changed, 119 insertions, 5 deletions
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
index 67e245c..1635039 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/src/text_recognizer/networks/__init__.py
@@ -4,6 +4,7 @@ from .crnn import ConvolutionalRecurrentNetwork
from .ctc import greedy_decoder
from .densenet import DenseNet
from .lenet import LeNet
+from .metrics import accuracy, accuracy_ignore_pad, cer, wer
from .mlp import MLP
from .residual_network import ResidualNetwork, ResidualNetworkEncoder
from .transformer import Transformer
@@ -11,6 +12,9 @@ from .util import sliding_window
from .wide_resnet import WideResidualNetwork
__all__ = [
+ "accuracy",
+ "accuracy_ignore_pad",
+ "cer",
"CNNTransformer",
"ConvolutionalRecurrentNetwork",
"DenseNet",
@@ -21,5 +25,6 @@ __all__ = [
"ResidualNetworkEncoder",
"sliding_window",
"Transformer",
+ "wer",
"WideResidualNetwork",
]
diff --git a/src/text_recognizer/networks/crnn.py b/src/text_recognizer/networks/crnn.py
index 9747429..778e232 100644
--- a/src/text_recognizer/networks/crnn.py
+++ b/src/text_recognizer/networks/crnn.py
@@ -1,4 +1,4 @@
-"""LSTM with CTC for handwritten text recognition within a line."""
+"""CRNN for handwritten text recognition."""
from typing import Dict, Tuple
from einops import rearrange, reduce
@@ -89,20 +89,22 @@ class ConvolutionalRecurrentNetwork(nn.Module):
x = self.backbone(x)
- # Avgerage pooling.
+ # Average pooling.
if self.avg_pool:
x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)
else:
x = rearrange(x, "(b t) h -> t b h", b=b, t=t)
else:
# Encode the entire image with a CNN, and use the channels as temporal dimension.
- b = x.shape[0]
x = self.backbone(x)
- x = rearrange(x, "b c h w -> c b (h w)", b=b)
+ x = rearrange(x, "b c h w -> b w c h")
+ if self.adaptive_pool is not None:
+ x = self.adaptive_pool(x)
+ x = x.squeeze(3)
# Sequence predictions.
x, _ = self.rnn(x)
- # Sequence to classifcation layer.
+ # Sequence to classification layer.
x = self.decoder(x)
return x
diff --git a/src/text_recognizer/networks/metrics.py b/src/text_recognizer/networks/metrics.py
new file mode 100644
index 0000000..af9adb5
--- /dev/null
+++ b/src/text_recognizer/networks/metrics.py
@@ -0,0 +1,107 @@
+"""Utility functions for models."""
+import Levenshtein as Lev
+import torch
+from torch import Tensor
+
+from text_recognizer.networks import greedy_decoder
+
+
+def accuracy_ignore_pad(
+    output: Tensor,
+    target: Tensor,
+    pad_index: int = 79,
+    eos_index: int = 81,
+    seq_len: int = 97,
+) -> float:
+    """Computes the accuracy after forcing every prediction past eos to pad.
+
+    The i-th eos found in `target` is paired with the end of the i-th block of
+    `seq_len` positions, and the predictions strictly between them are
+    replaced with `pad_index` before being compared with `target`.
+
+    Args:
+        output (Tensor): Network output. If it has more dimensions than
+            `target`, the last dimension is reduced with argmax to get the
+            predicted indices; otherwise it is used as predicted indices.
+        target (Tensor): Ground truth labels, indexed along dim 0 in
+            consecutive blocks of `seq_len` positions.
+        pad_index (int): Index written over predictions after eos.
+        eos_index (int): Index that marks the end of a sequence in `target`.
+        seq_len (int): Number of positions per sequence.
+
+    Returns:
+        float: Fraction of positions where the masked prediction equals the
+            target, or 0.0 if `target` is empty.
+
+    """
+    # Mask predicted indices rather than the raw scores: overwriting a row of
+    # scores with a constant makes argmax pick index 0, not pad_index.
+    if output.dim() > target.dim():
+        predicted = output.argmax(dim=-1)
+    else:
+        # Copy so the caller's tensor is never modified.
+        predicted = output.clone()
+
+    start_indices = torch.nonzero(target == eos_index, as_tuple=False).squeeze(1)
+    end_indices = torch.arange(seq_len, target.shape[0] + 1, seq_len)
+    # NOTE(review): the pairing assumes exactly one eos per sequence, in order.
+    for start, stop in zip(start_indices.tolist(), end_indices.tolist()):
+        predicted[start + 1 : stop] = pad_index
+
+    total = target.numel()
+    if total == 0:
+        return 0.0
+    return ((predicted == target).sum().float() / total).item()
+
+
+def accuracy(outputs: Tensor, labels: Tensor,) -> float:
+    """Computes the accuracy.
+
+    The predicted class is the argmax over the last dimension of `outputs`.
+    The number of matches is divided by the total number of label elements,
+    so the result stays in [0, 1] for labels with any number of dimensions.
+
+    Args:
+        outputs (Tensor): The output from the network, with class scores in
+            the last dimension.
+        labels (Tensor): Ground truth labels, matching `outputs` without its
+            last dimension.
+
+    Returns:
+        float: The accuracy for the batch, or 0.0 if `labels` is empty.
+
+    """
+    total = labels.numel()
+    if total == 0:
+        # Nothing to score; avoid a 0 / 0 division.
+        return 0.0
+
+    predicted = outputs.argmax(dim=-1)
+    correct = (predicted == labels).sum().float()
+    return (correct / total).item()
+
+
+def cer(outputs: Tensor, targets: Tensor) -> float:
+    """Computes the character error for a batch as a mean edit distance.
+
+    Predictions and targets are decoded with `greedy_decoder`, joined into
+    strings and stripped of spaces. The Levenshtein distance of each pair is
+    summed and divided by the number of decoded predictions. The distance is
+    not normalised by the target length.
+
+    Args:
+        outputs (Tensor): The output from the network. Its dim 1 is used as
+            the number of targets (presumably the batch size; confirm against
+            `greedy_decoder`).
+        targets (Tensor): Ground truth labels. Its dim 1 is used as the length
+            of every target.
+
+    Returns:
+        float: The mean character-level edit distance for the batch, or 0.0
+            if nothing was decoded.
+
+    """
+    target_lengths = torch.full(
+        size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
+    )
+    decoded_predictions, decoded_targets = greedy_decoder(
+        outputs, targets, target_lengths
+    )
+
+    num_samples = len(decoded_predictions)
+    if num_samples == 0:
+        # An empty batch would otherwise raise ZeroDivisionError.
+        return 0.0
+
+    lev_dist = sum(
+        Lev.distance(
+            "".join(prediction).replace(" ", ""), "".join(target).replace(" ", ""),
+        )
+        for prediction, target in zip(decoded_predictions, decoded_targets)
+    )
+    return lev_dist / num_samples
+
+
+def wer(outputs: Tensor, targets: Tensor) -> float:
+    """Computes the word error for a batch as a mean word-level edit distance.
+
+    Predictions and targets are decoded with `greedy_decoder`, joined into
+    strings and split on whitespace. Each distinct word of a pair is mapped to
+    a unique character, so the Levenshtein distance between the two encoded
+    strings counts word edits. The distances are summed and divided by the
+    number of decoded predictions. They are not normalised by the number of
+    target words.
+
+    Args:
+        outputs (Tensor): The output from the network. Its dim 1 is used as
+            the number of targets (presumably the batch size; confirm against
+            `greedy_decoder`).
+        targets (Tensor): Ground truth labels. Its dim 1 is used as the length
+            of every target.
+
+    Returns:
+        float: The mean word-level edit distance for the batch, or 0.0 if
+            nothing was decoded.
+
+    """
+    target_lengths = torch.full(
+        size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
+    )
+    decoded_predictions, decoded_targets = greedy_decoder(
+        outputs, targets, target_lengths
+    )
+
+    num_samples = len(decoded_predictions)
+    if num_samples == 0:
+        # An empty batch would otherwise raise ZeroDivisionError.
+        return 0.0
+
+    lev_dist = 0
+
+    for prediction, target in zip(decoded_predictions, decoded_targets):
+        predicted_words = "".join(prediction).split()
+        target_words = "".join(target).split()
+
+        # Give every distinct word its own character so that the character
+        # level edit distance below counts whole-word edits.
+        vocabulary = set(predicted_words + target_words)
+        word2char = {word: chr(index) for index, word in enumerate(vocabulary)}
+
+        encoded_prediction = "".join(word2char[word] for word in predicted_words)
+        encoded_target = "".join(word2char[word] for word in target_words)
+
+        lev_dist += Lev.distance(encoded_prediction, encoded_target)
+
+    return lev_dist / num_samples