diff options
Diffstat (limited to 'text_recognizer/models/transformer.py')
-rw-r--r-- | text_recognizer/models/transformer.py | 30 |
1 files changed, 19 insertions, 11 deletions
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 6be0ac5..ea54d83 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -1,27 +1,35 @@ """PyTorch Lightning model for base Transformers.""" from typing import Dict, List, Optional, Union, Tuple, Type +import attr from omegaconf import DictConfig from torch import nn, Tensor from text_recognizer.data.emnist import emnist_mapping +from text_recognizer.data.mappings import AbstractMapping from text_recognizer.models.metrics import CharacterErrorRate from text_recognizer.models.base import LitBaseModel -class LitTransformerModel(LitBaseModel): +@attr.s +class TransformerLitModel(LitBaseModel): """A PyTorch Lightning model for transformer networks.""" - def __init__( - self, - network: Type[nn.Module], - optimizer: Union[DictConfig, Dict], - lr_scheduler: Union[DictConfig, Dict], - criterion: Union[DictConfig, Dict], - monitor: str = "val_loss", - mapping: Optional[List[str]] = None, - ) -> None: - super().__init__(network, optimizer, lr_scheduler, criterion, monitor) + network: Type[nn.Module] = attr.ib() + criterion_config: DictConfig = attr.ib(converter=DictConfig) + optimizer_config: DictConfig = attr.ib(converter=DictConfig) + lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig) + monitor: str = attr.ib() + mapping: Type[AbstractMapping] = attr.ib() + + def __attrs_post_init__(self) -> None: + super().__init__( + network=self.network, + optimizer_config=self.optimizer_config, + lr_scheduler_config=self.lr_scheduler_config, + criterion_config=self.criterion_config, + monitor=self.monitor, + ) self.mapping, ignore_tokens = self.configure_mapping(mapping) self.val_cer = CharacterErrorRate(ignore_tokens) self.test_cer = CharacterErrorRate(ignore_tokens) |