diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-04 16:05:13 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-04 16:05:13 +0200 |
commit | 31e9673eef3088f08e3ee6aef8b78abd701ca329 (patch) | |
tree | f529d975d18d718a5d646e93f746d8be6f2f5cfe /text_recognizer | |
parent | 36964354407d0fdf73bdca2f611fee1664860197 (diff) |
Reformat test for CER
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/data/iam_extended_paragraphs.py | 1 | ||||
-rw-r--r-- | text_recognizer/models/base.py | 30 | ||||
-rw-r--r-- | text_recognizer/models/metrics.py | 32 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/positional_encoding.py | 16 |
4 files changed, 63 insertions, 16 deletions
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py index 51050fc..d2529b4 100644 --- a/text_recognizer/data/iam_extended_paragraphs.py +++ b/text_recognizer/data/iam_extended_paragraphs.py @@ -71,5 +71,6 @@ class IAMExtendedParagraphs(BaseDataModule): ) return basic + data + def show_dataset_info() -> None: load_and_print_info(IAMExtendedParagraphs) diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index e86b478..0c70625 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -6,6 +6,7 @@ import pytorch_lightning as pl import torch from torch import nn from torch import Tensor +import torchmetrics from text_recognizer import networks @@ -13,7 +14,14 @@ from text_recognizer import networks class BaseModel(pl.LightningModule): """Abstract PyTorch Lightning class.""" - def __init__(self, network_args: Dict, optimizer_args: Dict, lr_scheduler_args: Dict, criterion_args: Dict, monitor: str = "val_loss") -> None: + def __init__( + self, + network_args: Dict, + optimizer_args: Dict, + lr_scheduler_args: Dict, + criterion_args: Dict, + monitor: str = "val_loss", + ) -> None: super().__init__() self.monitor = monitor self.network = getattr(networks, network_args["type"])(**network_args["args"]) @@ -22,9 +30,9 @@ class BaseModel(pl.LightningModule): self.loss_fn = self.configure_criterion(criterion_args) # Accuracy metric - self.train_acc = pl.metrics.Accuracy() - self.val_acc = pl.metrics.Accuracy() - self.test_acc = pl.metrics.Accuracy() + self.train_acc = torchmetrics.Accuracy() + self.val_acc = torchmetrics.Accuracy() + self.test_acc = torchmetrics.Accuracy() @staticmethod def configure_criterion(criterion_args: Dict) -> Type[nn.Module]: @@ -41,8 +49,14 @@ class BaseModel(pl.LightningModule): optimizer = getattr(torch.optim, self.optimizer_args["type"])(**args) args = {} or self.lr_scheduler_args["args"] - scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_args["type"])(**args) - return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": self.monitor} + scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_args["type"])( + **args + ) + return { + "optimizer": optimizer, + "lr_scheduler": scheduler, + "monitor": self.monitor, + } def forward(self, data: Tensor) -> Tensor: """Feedforward pass.""" @@ -55,7 +69,7 @@ class BaseModel(pl.LightningModule): loss = self.loss_fn(logits, targets) self.log("train_loss", loss) self.train_acc(logits, targets) - self.log("train_acc": self.train_acc, on_step=False, on_epoch=True) + self.log("train_acc", self.train_acc, on_step=False, on_epoch=True) return loss def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: @@ -73,5 +87,3 @@ class BaseModel(pl.LightningModule): logits = self(data) self.test_acc(logits, targets) self.log("test_acc", self.test_acc, on_step=False, on_epoch=True) - - diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py new file mode 100644 index 0000000..58d0537 --- /dev/null +++ b/text_recognizer/models/metrics.py @@ -0,0 +1,32 @@ +"""Character Error Rate (CER).""" +from typing import Sequence + +import editdistance +import torch +from torch import Tensor +import torchmetrics + + +class CharacterErrorRate(torchmetrics.Metric): + """Character error rate metric, computed using Levenshtein distance.""" + + def __init__(self, ignore_tokens: Sequence[int], *args) -> None: + super().__init__() + self.ignore_tokens = set(ignore_tokens) + self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, preds: Tensor, targets: Tensor) -> None: + """Update CER.""" + bsz = preds.shape[0] + for index in range(bsz): + pred = [p for p in preds[index].tolist() if p not in self.ignore_tokens] + target = [t for t in targets[index].tolist() if t not in self.ignore_tokens] + distance = editdistance.distance(pred, target) + error = distance / max(len(pred), len(target)) + self.error += error + self.total += bsz + + def compute(self) -> Tensor: + """Compute CER.""" + return self.error / self.total diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py index d03f630..d67d297 100644 --- a/text_recognizer/networks/transformer/positional_encoding.py +++ b/text_recognizer/networks/transformer/positional_encoding.py @@ -16,7 +16,7 @@ class PositionalEncoding(nn.Module): self.dropout = nn.Dropout(p=dropout_rate) pe = self.make_pe(hidden_dim, max_len) self.register_buffer("pe", pe) - + @staticmethod def make_pe(hidden_dim: int, max_len: int) -> Tensor: """Returns positional encoding.""" @@ -40,7 +40,7 @@ class PositionalEncoding(nn.Module): class PositionalEncoding2D(nn.Module): """Positional encodings for feature maps.""" - def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int =2048) -> None: + def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int = 2048) -> None: super().__init__() if hidden_dim % 2 != 0: raise ValueError(f"Embedding depth {hidden_dim} is not even!") @@ -50,10 +50,14 @@ class PositionalEncoding2D(nn.Module): def make_pe(hidden_dim: int, max_h: int, max_w: int) -> Tensor: """Returns 2d postional encoding.""" - pe_h = PositionalEncoding.make_pe(hidden_dim // 2, max_len=max_h) # [H, 1, D // 2] + pe_h = PositionalEncoding.make_pe( + hidden_dim // 2, max_len=max_h + ) # [H, 1, D // 2] pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w) - pe_w = PositionalEncoding.make_pe(hidden_dim // 2, max_len=max_h) # [W, 1, D // 2] + pe_w = PositionalEncoding.make_pe( + hidden_dim // 2, max_len=max_h + ) # [W, 1, D // 2] pe_w = repeat(pe_w, "h w d -> d (h tile) w", tile=max_h) pe = torch.cat([pe_h, pe_w], dim=0) # [D, H, W] @@ -64,7 +68,5 @@ class PositionalEncoding2D(nn.Module): # Assumes x hase shape [B, D, H, W] if x.shape[1] != self.pe.shape[0]: raise ValueError("Hidden dimensions does not match.") - x += self.pe[:, :x.shape[2], :x.shape[3]] + x += self.pe[:, : x.shape[2], : x.shape[3]] return x - - |