diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-06 17:42:53 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-06 17:42:53 +0200 |
commit | eb5b206f7e1b08435378d2a02395307be55ee6f1 (patch) | |
tree | 0cd30234afab698eb632b20a7da97e3bc7e98882 /text_recognizer/models/metrics.py | |
parent | 4d1f2cef39688871d2caafce42a09316381a27ae (diff) |
Refactoring data with attrs and refactor conf for hydra
Diffstat (limited to 'text_recognizer/models/metrics.py')
-rw-r--r-- | text_recognizer/models/metrics.py | 15 |
1 files changed, 10 insertions, 5 deletions
diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py index 58d0537..4117ae2 100644 --- a/text_recognizer/models/metrics.py +++ b/text_recognizer/models/metrics.py @@ -1,18 +1,23 @@ """Character Error Rate (CER).""" -from typing import Sequence +from typing import Set, Sequence +import attr import editdistance import torch from torch import Tensor -import torchmetrics +from torchmetrics import Metric -class CharacterErrorRate(torchmetrics.Metric): +@attr.s +class CharacterErrorRate(Metric): """Character error rate metric, computed using Levenshtein distance.""" - def __init__(self, ignore_tokens: Sequence[int], *args) -> None: + ignore_tokens: Set = attr.ib(converter=set) + error: Tensor = attr.ib(init=False) + total: Tensor = attr.ib(init=False) + + def __attrs_post_init__(self) -> None: super().__init__() - self.ignore_tokens = set(ignore_tokens) self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") |