summaryrefslogtreecommitdiff
path: root/text_recognizer/models/transformer.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-06 17:42:53 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-06 17:42:53 +0200
commiteb5b206f7e1b08435378d2a02395307be55ee6f1 (patch)
tree0cd30234afab698eb632b20a7da97e3bc7e98882 /text_recognizer/models/transformer.py
parent4d1f2cef39688871d2caafce42a09316381a27ae (diff)
Refactoring data with attrs and refactor conf for hydra
Diffstat (limited to 'text_recognizer/models/transformer.py')
-rw-r--r--text_recognizer/models/transformer.py26
1 files changed, 8 insertions, 18 deletions
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index ea54d83..8c9fe8a 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -2,35 +2,24 @@
from typing import Dict, List, Optional, Union, Tuple, Type
import attr
+import hydra
from omegaconf import DictConfig
from torch import nn, Tensor
from text_recognizer.data.emnist import emnist_mapping
from text_recognizer.data.mappings import AbstractMapping
from text_recognizer.models.metrics import CharacterErrorRate
-from text_recognizer.models.base import LitBaseModel
+from text_recognizer.models.base import BaseLitModel
-@attr.s
-class TransformerLitModel(LitBaseModel):
+@attr.s(auto_attribs=True)
+class TransformerLitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
- network: Type[nn.Module] = attr.ib()
- criterion_config: DictConfig = attr.ib(converter=DictConfig)
- optimizer_config: DictConfig = attr.ib(converter=DictConfig)
- lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig)
- monitor: str = attr.ib()
- mapping: Type[AbstractMapping] = attr.ib()
+ mapping_config: DictConfig = attr.ib(converter=DictConfig)
def __attrs_post_init__(self) -> None:
- super().__init__(
- network=self.network,
- optimizer_config=self.optimizer_config,
- lr_scheduler_config=self.lr_scheduler_config,
- criterion_config=self.criterion_config,
- monitor=self.monitor,
- )
- self.mapping, ignore_tokens = self.configure_mapping(mapping)
+ self.mapping, ignore_tokens = self._configure_mapping()
self.val_cer = CharacterErrorRate(ignore_tokens)
self.test_cer = CharacterErrorRate(ignore_tokens)
@@ -39,9 +28,10 @@ class TransformerLitModel(LitBaseModel):
return self.network.predict(data)
@staticmethod
- def configure_mapping(mapping: Optional[List[str]]) -> Tuple[List[str], List[int]]:
+ def _configure_mapping() -> Tuple[Type[AbstractMapping], List[int]]:
"""Configure mapping."""
# TODO: Fix me!!!
+ # Load config with hydra
mapping, inverse_mapping, _ = emnist_mapping(["\n"])
start_index = inverse_mapping["<s>"]
end_index = inverse_mapping["<e>"]