summaryrefslogtreecommitdiff
path: root/src/text_recognizer/models
diff options
context:
space:
mode:
authoraktersnurra <grydholm@kth.se>2020-12-02 23:48:52 +0100
committeraktersnurra <grydholm@kth.se>2020-12-02 23:48:52 +0100
commit5529e0fc9ca39e81fe0f08a54f257d32f0afe120 (patch)
treef2be992554e278857db7d56786dba54a76d439c7 /src/text_recognizer/models
parente3b039c9adb4bce42ede4cb682a3ae71e797539a (diff)
parent8e3985c9cde6666e4314973312135ec1c7a025b9 (diff)
Merge branch 'master' of github.com:aktersnurra/text-recognizer
Diffstat (limited to 'src/text_recognizer/models')
-rw-r--r--src/text_recognizer/models/__init__.py5
-rw-r--r--src/text_recognizer/models/metrics.py107
2 files changed, 0 insertions, 112 deletions
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py
index 53340f1..bf89404 100644
--- a/src/text_recognizer/models/__init__.py
+++ b/src/text_recognizer/models/__init__.py
@@ -2,16 +2,11 @@
from .base import Model
from .character_model import CharacterModel
from .crnn_model import CRNNModel
-from .metrics import accuracy, accuracy_ignore_pad, cer, wer
from .transformer_model import TransformerModel
__all__ = [
- "accuracy",
- "accuracy_ignore_pad",
- "cer",
"CharacterModel",
"CRNNModel",
"Model",
"TransformerModel",
- "wer",
]
diff --git a/src/text_recognizer/models/metrics.py b/src/text_recognizer/models/metrics.py
deleted file mode 100644
index af9adb5..0000000
--- a/src/text_recognizer/models/metrics.py
+++ /dev/null
@@ -1,107 +0,0 @@
-"""Utility functions for models."""
-import Levenshtein as Lev
-import torch
-from torch import Tensor
-
-from text_recognizer.networks import greedy_decoder
-
-
-def accuracy_ignore_pad(
- output: Tensor,
- target: Tensor,
- pad_index: int = 79,
- eos_index: int = 81,
- seq_len: int = 97,
-) -> float:
- """Sets all predictions after eos to pad."""
- start_indices = torch.nonzero(target == eos_index, as_tuple=False).squeeze(1)
- end_indices = torch.arange(seq_len, target.shape[0] + 1, seq_len)
- for start, stop in zip(start_indices, end_indices):
- output[start + 1 : stop] = pad_index
-
- return accuracy(output, target)
-
-
-def accuracy(outputs: Tensor, labels: Tensor,) -> float:
- """Computes the accuracy.
-
- Args:
- outputs (Tensor): The output from the network.
- labels (Tensor): Ground truth labels.
-
- Returns:
- float: The accuracy for the batch.
-
- """
-
- _, predicted = torch.max(outputs, dim=-1)
-
- acc = (predicted == labels).sum().float() / labels.shape[0]
- acc = acc.item()
- return acc
-
-
-def cer(outputs: Tensor, targets: Tensor) -> float:
- """Computes the character error rate.
-
- Args:
- outputs (Tensor): The output from the network.
- targets (Tensor): Ground truth labels.
-
- Returns:
- float: The cer for the batch.
-
- """
- target_lengths = torch.full(
- size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
- )
- decoded_predictions, decoded_targets = greedy_decoder(
- outputs, targets, target_lengths
- )
-
- lev_dist = 0
-
- for prediction, target in zip(decoded_predictions, decoded_targets):
- prediction = "".join(prediction)
- target = "".join(target)
- prediction, target = (
- prediction.replace(" ", ""),
- target.replace(" ", ""),
- )
- lev_dist += Lev.distance(prediction, target)
- return lev_dist / len(decoded_predictions)
-
-
-def wer(outputs: Tensor, targets: Tensor) -> float:
- """Computes the Word error rate.
-
- Args:
- outputs (Tensor): The output from the network.
- targets (Tensor): Ground truth labels.
-
- Returns:
- float: The wer for the batch.
-
- """
- target_lengths = torch.full(
- size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
- )
- decoded_predictions, decoded_targets = greedy_decoder(
- outputs, targets, target_lengths
- )
-
- lev_dist = 0
-
- for prediction, target in zip(decoded_predictions, decoded_targets):
- prediction = "".join(prediction)
- target = "".join(target)
-
- b = set(prediction.split() + target.split())
- word2char = dict(zip(b, range(len(b))))
-
- w1 = [chr(word2char[w]) for w in prediction.split()]
- w2 = [chr(word2char[w]) for w in target.split()]
-
- lev_dist += Lev.distance("".join(w1), "".join(w2))
-
- return lev_dist / len(decoded_predictions)