blob: 78f58543a92ce95c826b1cb612f2f5f9e306d478 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
|
"""Word Error Rate (WER) over token-id sequences, with support for ignored tokens."""
from typing import Sequence
import torch
import torchmetrics
class WordErrorRate(torchmetrics.WordErrorRate):
    """Word error rate over token-id tensors, allowing for tokens to be ignored.

    ``torchmetrics.WordErrorRate`` scores whitespace-separated strings, so each
    token-id sequence is filtered and then rendered as a space-separated string
    of ids before being handed to the base class. Every token id therefore
    counts as one "word" in the edit-distance computation.

    Args:
        ignore_tokens: token ids to drop from both predictions and targets
            before scoring.
        *args: forwarded unchanged to ``torchmetrics.WordErrorRate``.
        **kwargs: forwarded unchanged to ``torchmetrics.WordErrorRate``.
    """

    def __init__(self, ignore_tokens: Sequence[int], *args, **kwargs):
        super().__init__(*args, **kwargs)
        # A set gives O(1) membership checks in the per-token filter below.
        self.ignore_tokens = set(ignore_tokens)

    def _to_word_strings(self, batch: torch.Tensor) -> Sequence[str]:
        """Turn a tensor of token ids into one space-separated string per sequence.

        A 2-D tensor is treated as a batch of sequences (one per row). A 1-D
        tensor is treated as a single sequence.
        """
        rows = batch.tolist()
        if batch.dim() == 1:
            # Single sequence: wrap it so it is scored as one batch item.
            rows = [rows]
        return [
            " ".join(str(token) for token in row if token not in self.ignore_tokens)
            for row in rows
        ]

    def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None:
        """Accumulate WER statistics for a batch of predicted and target token ids.

        Args:
            preds: predicted token ids, shape (batch, seq) or (seq,).
            targets: reference token ids, same layout as ``preds``.
        """
        # The base metric expects strings and calls ``.split()`` on each item;
        # passing raw lists of ints would raise AttributeError.
        super().update(self._to_word_strings(preds), self._to_word_strings(targets))
|