summaryrefslogtreecommitdiff
path: root/text_recognizer/models/transformer.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/models/transformer.py')
-rw-r--r--text_recognizer/models/transformer.py30
1 files changed, 19 insertions, 11 deletions
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 6be0ac5..ea54d83 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,27 +1,35 @@
"""PyTorch Lightning model for base Transformers."""
from typing import Dict, List, Optional, Union, Tuple, Type
+import attr
from omegaconf import DictConfig
from torch import nn, Tensor
from text_recognizer.data.emnist import emnist_mapping
+from text_recognizer.data.mappings import AbstractMapping
from text_recognizer.models.metrics import CharacterErrorRate
from text_recognizer.models.base import LitBaseModel
-class LitTransformerModel(LitBaseModel):
+@attr.s
+class TransformerLitModel(LitBaseModel):
"""A PyTorch Lightning model for transformer networks."""
- def __init__(
- self,
- network: Type[nn.Module],
- optimizer: Union[DictConfig, Dict],
- lr_scheduler: Union[DictConfig, Dict],
- criterion: Union[DictConfig, Dict],
- monitor: str = "val_loss",
- mapping: Optional[List[str]] = None,
- ) -> None:
- super().__init__(network, optimizer, lr_scheduler, criterion, monitor)
+ network: Type[nn.Module] = attr.ib()
+ criterion_config: DictConfig = attr.ib(converter=DictConfig)
+ optimizer_config: DictConfig = attr.ib(converter=DictConfig)
+ lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig)
+ monitor: str = attr.ib()
+ mapping: Type[AbstractMapping] = attr.ib()
+
+ def __attrs_post_init__(self) -> None:
+ super().__init__(
+ network=self.network,
+ optimizer_config=self.optimizer_config,
+ lr_scheduler_config=self.lr_scheduler_config,
+ criterion_config=self.criterion_config,
+ monitor=self.monitor,
+ )
self.mapping, ignore_tokens = self.configure_mapping(mapping)
self.val_cer = CharacterErrorRate(ignore_tokens)
self.test_cer = CharacterErrorRate(ignore_tokens)