diff options
Diffstat (limited to 'text_recognizer/models')
-rw-r--r-- | text_recognizer/models/base.py | 24 | ||||
-rw-r--r-- | text_recognizer/models/metrics.py | 10 | ||||
-rw-r--r-- | text_recognizer/models/transformer.py | 24 |
3 files changed, 27 insertions, 31 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 821cb69..bf3bc08 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -1,7 +1,7 @@ """Base PyTorch Lightning model.""" from typing import Any, Dict, List, Optional, Tuple, Type -import attr +from attrs import define, field import hydra from loguru import logger as log from omegaconf import DictConfig @@ -14,7 +14,7 @@ import torchmetrics from text_recognizer.data.mappings.base import AbstractMapping -@attr.s(eq=False) +@define(eq=False) class BaseLitModel(LightningModule): """Abstract PyTorch Lightning class.""" @@ -22,22 +22,18 @@ class BaseLitModel(LightningModule): """Pre init constructor.""" super().__init__() - network: Type[nn.Module] = attr.ib() - loss_fn: Type[nn.Module] = attr.ib() - optimizer_configs: DictConfig = attr.ib() - lr_scheduler_configs: Optional[DictConfig] = attr.ib() - mapping: Type[AbstractMapping] = attr.ib() + network: Type[nn.Module] = field() + loss_fn: Type[nn.Module] = field() + optimizer_configs: DictConfig = field() + lr_scheduler_configs: Optional[DictConfig] = field() + mapping: Type[AbstractMapping] = field() # Placeholders - train_acc: torchmetrics.Accuracy = attr.ib( - init=False, default=torchmetrics.Accuracy() - ) - val_acc: torchmetrics.Accuracy = attr.ib( - init=False, default=torchmetrics.Accuracy() - ) - test_acc: torchmetrics.Accuracy = attr.ib( + train_acc: torchmetrics.Accuracy = field( init=False, default=torchmetrics.Accuracy() ) + val_acc: torchmetrics.Accuracy = field(init=False, default=torchmetrics.Accuracy()) + test_acc: torchmetrics.Accuracy = field(init=False, default=torchmetrics.Accuracy()) def optimizer_zero_grad( self, diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py index f83c9e4..e59a830 100644 --- a/text_recognizer/models/metrics.py +++ b/text_recognizer/models/metrics.py @@ -1,20 +1,20 @@ """Character Error Rate (CER).""" from typing import Set -import attr +from attrs import define, field import editdistance import torch from torch import Tensor from torchmetrics import Metric -@attr.s(eq=False) +@define(eq=False) class CharacterErrorRate(Metric): """Character error rate metric, computed using Levenshtein distance.""" - ignore_indices: Set[Tensor] = attr.ib(converter=set) - error: Tensor = attr.ib(init=False) - total: Tensor = attr.ib(init=False) + ignore_indices: Set[Tensor] = field(converter=set) + error: Tensor = field(init=False) + total: Tensor = field(init=False) def __attrs_post_init__(self) -> None: super().__init__() diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 7272f46..c5120fe 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -1,7 +1,7 @@ """PyTorch Lightning model for base Transformers.""" from typing import Set, Tuple -import attr +from attrs import define, field import torch from torch import Tensor @@ -9,22 +9,22 @@ from text_recognizer.models.base import BaseLitModel from text_recognizer.models.metrics import CharacterErrorRate -@attr.s(auto_attribs=True, eq=False) +@define(auto_attribs=True, eq=False) class TransformerLitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" - max_output_len: int = attr.ib(default=451) - start_token: str = attr.ib(default="<s>") - end_token: str = attr.ib(default="<e>") - pad_token: str = attr.ib(default="<p>") + max_output_len: int = field(default=451) + start_token: str = field(default="<s>") + end_token: str = field(default="<e>") + pad_token: str = field(default="<p>") - start_index: int = attr.ib(init=False) - end_index: int = attr.ib(init=False) - pad_index: int = attr.ib(init=False) + start_index: int = field(init=False) + end_index: int = field(init=False) + pad_index: int = field(init=False) - ignore_indices: Set[Tensor] = attr.ib(init=False) - val_cer: CharacterErrorRate = attr.ib(init=False) - test_cer: CharacterErrorRate = attr.ib(init=False) + ignore_indices: Set[Tensor] = field(init=False) + val_cer: CharacterErrorRate = field(init=False) + test_cer: CharacterErrorRate = field(init=False) def __attrs_post_init__(self) -> None: """Post init configuration.""" |