summary | refs | log | tree | commit | diff
path: root/text_recognizer/models/transformer.py
diff options
context:
space:
mode:
author: Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-07-05 23:05:25 +0200
committer: Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-07-05 23:05:25 +0200
commit: 4d1f2cef39688871d2caafce42a09316381a27ae (patch)
tree: 0f4385969e7df6d7d313cd5910bde9a7475ca027 /text_recognizer/models/transformer.py
parent: f0481decdad9afb52494e9e95996deef843ef233 (diff)
Refactor with attr, working on cnn+transformer network
Diffstat (limited to 'text_recognizer/models/transformer.py')
-rw-r--r-- text_recognizer/models/transformer.py 30
1 files changed, 19 insertions, 11 deletions
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 6be0ac5..ea54d83 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,27 +1,35 @@
"""PyTorch Lightning model for base Transformers."""
from typing import Dict, List, Optional, Union, Tuple, Type
+import attr
from omegaconf import DictConfig
from torch import nn, Tensor
from text_recognizer.data.emnist import emnist_mapping
+from text_recognizer.data.mappings import AbstractMapping
from text_recognizer.models.metrics import CharacterErrorRate
from text_recognizer.models.base import LitBaseModel
-class LitTransformerModel(LitBaseModel):
+@attr.s
+class TransformerLitModel(LitBaseModel):
"""A PyTorch Lightning model for transformer networks."""
- def __init__(
- self,
- network: Type[nn.Module],
- optimizer: Union[DictConfig, Dict],
- lr_scheduler: Union[DictConfig, Dict],
- criterion: Union[DictConfig, Dict],
- monitor: str = "val_loss",
- mapping: Optional[List[str]] = None,
- ) -> None:
- super().__init__(network, optimizer, lr_scheduler, criterion, monitor)
+ network: Type[nn.Module] = attr.ib()
+ criterion_config: DictConfig = attr.ib(converter=DictConfig)
+ optimizer_config: DictConfig = attr.ib(converter=DictConfig)
+ lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig)
+ monitor: str = attr.ib()
+ mapping: Type[AbstractMapping] = attr.ib()
+
+ def __attrs_post_init__(self) -> None:
+ super().__init__(
+ network=self.network,
+ optimizer_config=self.optimizer_config,
+ lr_scheduler_config=self.lr_scheduler_config,
+ criterion_config=self.criterion_config,
+ monitor=self.monitor,
+ )
self.mapping, ignore_tokens = self.configure_mapping(mapping)
self.val_cer = CharacterErrorRate(ignore_tokens)
self.test_cer = CharacterErrorRate(ignore_tokens)