diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-04 16:05:13 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-04 16:05:13 +0200 |
commit | 31e9673eef3088f08e3ee6aef8b78abd701ca329 (patch) | |
tree | f529d975d18d718a5d646e93f746d8be6f2f5cfe | |
parent | 36964354407d0fdf73bdca2f611fee1664860197 (diff) |
Reformat test for CER
-rw-r--r-- | poetry.lock | 42 | ||||
-rw-r--r-- | pyproject.toml | 2 | ||||
-rw-r--r-- | tests/test_cer.py | 23 | ||||
-rw-r--r-- | text_recognizer/data/iam_extended_paragraphs.py | 1 | ||||
-rw-r--r-- | text_recognizer/models/base.py | 30 | ||||
-rw-r--r-- | text_recognizer/models/metrics.py | 32 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/positional_encoding.py | 16 |
7 files changed, 129 insertions, 17 deletions
diff --git a/poetry.lock b/poetry.lock index f5b3fb0..3e5fcdf 100644 --- a/poetry.lock +++ b/poetry.lock @@ -341,6 +341,14 @@ toml = "*" pipenv = ["pipenv"] [[package]] +name = "editdistance" +version = "0.5.3" +description = "Fast implementation of the edit distance(Levenshtein distance)" +category = "main" +optional = false +python-versions = "*" + +[[package]] name = "einops" version = "0.3.0" description = "A new flavour of deep learning operations" @@ -2188,7 +2196,7 @@ multidict = ">=4.0" [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "db4253add1258abaf637f127ba576b1ec5e0b415c8f7f93b18ecdca40bc6f042" +content-hash = "478242efb3644e920eb8bce6710d1f0d4548635b0e8c9f758f665814af7f8300" [metadata.files] absl-py = [ @@ -2456,6 +2464,38 @@ dparse = [ {file = "dparse-0.5.1-py3-none-any.whl", hash = "sha256:e953a25e44ebb60a5c6efc2add4420c177f1d8404509da88da9729202f306994"}, {file = "dparse-0.5.1.tar.gz", hash = "sha256:a1b5f169102e1c894f9a7d5ccf6f9402a836a5d24be80a986c7ce9eaed78f367"}, ] +editdistance = [ + {file = "editdistance-0.5.3-cp27-cp27m-macosx_10_6_intel.whl", hash = "sha256:ef4714dc9cf281863dcc3ba6d24c3cae1dde41610a78dcdfae50d743ca71d5e1"}, + {file = "editdistance-0.5.3-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:a322354a8dfb442770902f06552b20df5184e65e84ac90cb799740915eb52212"}, + {file = "editdistance-0.5.3-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:36a4c36d7945f5ecfa1dc92c08635d73b64769cd0af066da774437fe2c7dc80a"}, + {file = "editdistance-0.5.3-cp27-cp27m-win32.whl", hash = "sha256:93e847cc2fbebb34a36b41337a3eb9b2034d4ff9679665b08ecc5c3c313f83a9"}, + {file = "editdistance-0.5.3-cp27-cp27m-win_amd64.whl", hash = "sha256:d4561b602b7675f6a050cdd0e1b652007ce73bb7290019487b8919a44593d74d"}, + {file = "editdistance-0.5.3-cp27-cp27mu-manylinux1_i686.whl", hash = "sha256:dddb0d36f698e3c942d0d5934185533d9324fbde975b3e956a19883713e86d33"}, + {file = "editdistance-0.5.3-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:1018f0fa857b079c721583c42d2c54800fbe8c7d2c29b354a9724a0b79971cb8"}, + {file = "editdistance-0.5.3-cp33-cp33m-win32.whl", hash = "sha256:810d93e614f35ad2916570f48ff1370ac3c001eb6941d5e836e2c1c6986fafff"}, + {file = "editdistance-0.5.3-cp33-cp33m-win_amd64.whl", hash = "sha256:a96ac49acc7668477c13aff02ca0527c6462b026b78600602dbef04efc9250d3"}, + {file = "editdistance-0.5.3-cp34-cp34m-macosx_10_6_intel.whl", hash = "sha256:a9167d9d5e754abd7ce68da065a636cc161e5063c322efd81159d15001d5272a"}, + {file = "editdistance-0.5.3-cp34-cp34m-manylinux1_i686.whl", hash = "sha256:a10c61df748220b2b9e2949a10aea23ffeded28c07e610e107a8f6a4b5b92782"}, + {file = "editdistance-0.5.3-cp34-cp34m-manylinux1_x86_64.whl", hash = "sha256:6452d750fbc49c6f04232a840f96b0f1155ff7cb2d953ce1edf075c5a394f3ea"}, + {file = "editdistance-0.5.3-cp34-cp34m-win32.whl", hash = "sha256:1f510e6eb411ec6123ba4ebc086d5882027710d28db174985a74e13fd0eb354f"}, + {file = "editdistance-0.5.3-cp34-cp34m-win_amd64.whl", hash = "sha256:9d6ee66f8de30ec6358083e5ecd7919a5966b38c64012c1672f326c61ff7a15f"}, + {file = "editdistance-0.5.3-cp35-cp35m-macosx_10_6_intel.whl", hash = "sha256:c1cf5ff98cfdc38046ae0f2d3ccbe1e15b0665234a04783f6558ec0a48e72dc8"}, + {file = "editdistance-0.5.3-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:5f9c202b1a2f2630f7a0cdd76ad0ad55de4cd700553778c77e37379c6ac8e8bb"}, + {file = "editdistance-0.5.3-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:553fb295802c399f0f419b616b499c241ffdcb2a70888d1e9d1bd22ba21b122f"}, + {file = "editdistance-0.5.3-cp35-cp35m-win32.whl", hash = "sha256:0834826832e51a6c18032b13b68083e3ebfbf3daf774142ae6f2b17b35580c16"}, + {file = "editdistance-0.5.3-cp35-cp35m-win_amd64.whl", hash = "sha256:6ccfd57221bae661304e7f9495f508aeec8f72e462d97481d55488ded87f5cbc"}, + {file = "editdistance-0.5.3-cp36-cp36m-macosx_10_6_intel.whl", hash = "sha256:25b39c836347dcbb251a6041fd3d7575b82c365923a4b13c32c699e442b1b644"}, + {file = "editdistance-0.5.3-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:fa0047a8d972ab779141eed4713811251c9f6e96e9e8a62caa8d554a0444ff74"}, + {file = "editdistance-0.5.3-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:db65bf1f39964019040434cb924c62c9965bd0df2feb316dbe5de3f09e6a81de"}, + {file = "editdistance-0.5.3-cp36-cp36m-win32.whl", hash = "sha256:cc65c2cd68751a966f7468537b4a6fd7d9107d49e139d8efd5734ee6f48d3126"}, + {file = "editdistance-0.5.3-cp36-cp36m-win_amd64.whl", hash = "sha256:fe7e6a90476976d7e5abc9472acb0311b7cdc76d84190f8f6c317234680c5de3"}, + {file = "editdistance-0.5.3-cp37-cp37m-macosx_10_6_intel.whl", hash = "sha256:25dd59d7f17a38203c5e433f5b11f64a8d1042d876d0dc00b324dda060d12e81"}, + {file = "editdistance-0.5.3-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:61486173447a153cccbd52eb63947378803f0f2a5bffebbfec500bd77fc5706d"}, + {file = "editdistance-0.5.3-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:cd49e9b22972b15527d53e06918c14d9fe228ae362a57476d16b0cad3e14e0c8"}, + {file = "editdistance-0.5.3-cp37-cp37m-win32.whl", hash = "sha256:503c6f69f4901d8a63f3748e4b0eccb2a89e6844b0879a7e256cab439297d379"}, + {file = "editdistance-0.5.3-cp37-cp37m-win_amd64.whl", hash = "sha256:ee4ed815bc5137a794095368580334e430ff26c73a05c67e76b39f535b363a0f"}, + {file = "editdistance-0.5.3.tar.gz", hash = "sha256:89d016dda04649b2c49e12b34337755a7b612bfd690420edd50ab31787120c1f"}, +] einops = [ {file = "einops-0.3.0-py2.py3-none-any.whl", hash = "sha256:a91c6190ceff7d513d74ca9fd701dfa6a1ffcdd98ea0ced14350197c07f75c73"}, {file = "einops-0.3.0.tar.gz", hash = "sha256:a3b0935a4556f012cd5fa1851373f63366890a3f6698d117afea55fd2a40c1fc"}, diff --git a/pyproject.toml b/pyproject.toml index 252678f..abf69ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,8 @@ sentencepiece = "^0.1.95" pytorch-lightning = "^1.2.4" Pillow = "^8.1.2" madgrad = "^1.0" +editdistance = "^0.5.3" +torchmetrics = "^0.2.0" [tool.poetry.dev-dependencies] pytest = "^5.4.2" diff --git a/tests/test_cer.py b/tests/test_cer.py new file mode 100644 index 0000000..30d58b2 --- /dev/null +++ b/tests/test_cer.py @@ -0,0 +1,23 @@ +"""Test the CER metric.""" +import torch + +from text_recognizer.models.metrics import CharacterErrorRate + + +def test_character_error_rate() -> None: + """Test CER computation.""" + metric = CharacterErrorRate([0, 1]) + preds = torch.Tensor( + [ + [0, 2, 2, 3, 3, 1], # error will be 0 + [0, 2, 1, 1, 1, 1], # error will be 0.75 + [0, 2, 2, 4, 4, 1], # error will be 0.5 + ] + ) + + targets = torch.Tensor([[0, 2, 2, 3, 3, 1], [0, 2, 2, 3, 3, 1], [0, 2, 2, 3, 3, 1]]) + metric(preds, targets) + print(metric.compute()) + assert metric.compute() == float(sum([0, 0.75, 0.5]) / 3) + + diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py index 51050fc..d2529b4 100644 --- a/text_recognizer/data/iam_extended_paragraphs.py +++ b/text_recognizer/data/iam_extended_paragraphs.py @@ -71,5 +71,6 @@ class IAMExtendedParagraphs(BaseDataModule): ) return basic + data + def show_dataset_info() -> None: load_and_print_info(IAMExtendedParagraphs) diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index e86b478..0c70625 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -6,6 +6,7 @@ import pytorch_lightning as pl import torch from torch import nn from torch import Tensor +import torchmetrics from text_recognizer import networks @@ -13,7 +14,14 @@ from text_recognizer import networks class BaseModel(pl.LightningModule): """Abstract PyTorch Lightning class.""" - def __init__(self, network_args: Dict, optimizer_args: Dict, lr_scheduler_args: Dict, criterion_args: Dict, monitor: str = "val_loss") -> None: + def __init__( + self, + network_args: Dict, + optimizer_args: Dict, + lr_scheduler_args: Dict, + criterion_args: Dict, + monitor: str = "val_loss", + ) -> None: super().__init__() self.monitor = monitor self.network = getattr(networks, network_args["type"])(**network_args["args"]) @@ -22,9 +30,9 @@ class BaseModel(pl.LightningModule): self.loss_fn = self.configure_criterion(criterion_args) # Accuracy metric - self.train_acc = pl.metrics.Accuracy() - self.val_acc = pl.metrics.Accuracy() - self.test_acc = pl.metrics.Accuracy() + self.train_acc = torchmetrics.Accuracy() + self.val_acc = torchmetrics.Accuracy() + self.test_acc = torchmetrics.Accuracy() @staticmethod def configure_criterion(criterion_args: Dict) -> Type[nn.Module]: @@ -41,8 +49,14 @@ class BaseModel(pl.LightningModule): optimizer = getattr(torch.optim, self.optimizer_args["type"])(**args) args = {} or self.lr_scheduler_args["args"] - scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_args["type"])(**args) - return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": self.monitor} + scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_args["type"])( + **args + ) + return { + "optimizer": optimizer, + "lr_scheduler": scheduler, + "monitor": self.monitor, + } def forward(self, data: Tensor) -> Tensor: """Feedforward pass.""" @@ -55,7 +69,7 @@ class BaseModel(pl.LightningModule): loss = self.loss_fn(logits, targets) self.log("train_loss", loss) self.train_acc(logits, targets) - self.log("train_acc": self.train_acc, on_step=False, on_epoch=True) + self.log("train_acc", self.train_acc, on_step=False, on_epoch=True) return loss def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: @@ -73,5 +87,3 @@ class BaseModel(pl.LightningModule): logits = self(data) self.test_acc(logits, targets) self.log("test_acc", self.test_acc, on_step=False, on_epoch=True) - - diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py new file mode 100644 index 0000000..58d0537 --- /dev/null +++ b/text_recognizer/models/metrics.py @@ -0,0 +1,32 @@ +"""Character Error Rate (CER).""" +from typing import Sequence + +import editdistance +import torch +from torch import Tensor +import torchmetrics + + +class CharacterErrorRate(torchmetrics.Metric): + """Character error rate metric, computed using Levenshtein distance.""" + + def __init__(self, ignore_tokens: Sequence[int], *args) -> None: + super().__init__() + self.ignore_tokens = set(ignore_tokens) + self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, preds: Tensor, targets: Tensor) -> None: + """Update CER.""" + bsz = preds.shape[0] + for index in range(bsz): + pred = [p for p in preds[index].tolist() if p not in self.ignore_tokens] + target = [t for t in targets[index].tolist() if t not in self.ignore_tokens] + distance = editdistance.distance(pred, target) + error = distance / max(len(pred), len(target)) + self.error += error + self.total += bsz + + def compute(self) -> Tensor: + """Compute CER.""" + return self.error / self.total diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py index d03f630..d67d297 100644 --- a/text_recognizer/networks/transformer/positional_encoding.py +++ b/text_recognizer/networks/transformer/positional_encoding.py @@ -16,7 +16,7 @@ class PositionalEncoding(nn.Module): self.dropout = nn.Dropout(p=dropout_rate) pe = self.make_pe(hidden_dim, max_len) self.register_buffer("pe", pe) - + @staticmethod def make_pe(hidden_dim: int, max_len: int) -> Tensor: """Returns positional encoding.""" @@ -40,7 +40,7 @@ class PositionalEncoding(nn.Module): class PositionalEncoding2D(nn.Module): """Positional encodings for feature maps.""" - def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int =2048) -> None: + def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int = 2048) -> None: super().__init__() if hidden_dim % 2 != 0: raise ValueError(f"Embedding depth {hidden_dim} is not even!") @@ -50,10 +50,14 @@ class PositionalEncoding2D(nn.Module): def make_pe(hidden_dim: int, max_h: int, max_w: int) -> Tensor: """Returns 2d postional encoding.""" - pe_h = PositionalEncoding.make_pe(hidden_dim // 2, max_len=max_h) # [H, 1, D // 2] + pe_h = PositionalEncoding.make_pe( + hidden_dim // 2, max_len=max_h + ) # [H, 1, D // 2] pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w) - pe_w = PositionalEncoding.make_pe(hidden_dim // 2, max_len=max_h) # [W, 1, D // 2] + pe_w = PositionalEncoding.make_pe( + hidden_dim // 2, max_len=max_h + ) # [W, 1, D // 2] pe_w = repeat(pe_w, "h w d -> d (h tile) w", tile=max_h) pe = torch.cat([pe_h, pe_w], dim=0) # [D, H, W] @@ -64,7 +68,5 @@ class PositionalEncoding2D(nn.Module): # Assumes x hase shape [B, D, H, W] if x.shape[1] != self.pe.shape[0]: raise ValueError("Hidden dimensions does not match.") - x += self.pe[:, :x.shape[2], :x.shape[3]] + x += self.pe[:, : x.shape[2], : x.shape[3]] return x - - |