diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-05 23:05:25 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-05 23:05:25 +0200 |
commit | 4d1f2cef39688871d2caafce42a09316381a27ae (patch) | |
tree | 0f4385969e7df6d7d313cd5910bde9a7475ca027 /text_recognizer/models/base.py | |
parent | f0481decdad9afb52494e9e95996deef843ef233 (diff) |
Refactor with attr, working on cnn+transformer network
Diffstat (limited to 'text_recognizer/models/base.py')
-rw-r--r-- | text_recognizer/models/base.py | 11 |
1 files changed, 5 insertions, 6 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 4e803eb..8dc7a36 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -1,5 +1,5 @@ """Base PyTorch Lightning model.""" -from typing import Any, Dict, List, Union, Tuple, Type +from typing import Any, Dict, List, Tuple, Type import attr import hydra @@ -13,7 +13,7 @@ import torchmetrics @attr.s -class LitBaseModel(pl.LightningModule): +class BaseLitModel(pl.LightningModule): """Abstract PyTorch Lightning class.""" network: Type[nn.Module] = attr.ib() @@ -30,18 +30,17 @@ class LitBaseModel(pl.LightningModule): val_acc = attr.ib(init=False) test_acc = attr.ib(init=False) - def __attrs_pre_init__(self): + def __attrs_pre_init__(self) -> None: super().__init__() - def __attrs_post_init__(self): - self.loss_fn = self.configure_criterion() + def __attrs_post_init__(self) -> None: + self.loss_fn = self._configure_criterion() # Accuracy metric self.train_acc = torchmetrics.Accuracy() self.val_acc = torchmetrics.Accuracy() self.test_acc = torchmetrics.Accuracy() - @staticmethod def configure_criterion(self) -> Type[nn.Module]: """Returns a loss functions.""" log.info(f"Instantiating criterion <{self.criterion_config._target_}>") |