diff options
Diffstat (limited to 'text_recognizer/models/base.py')
-rw-r--r-- | text_recognizer/models/base.py | 11 |
1 files changed, 5 insertions, 6 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 4e803eb..8dc7a36 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -1,5 +1,5 @@ """Base PyTorch Lightning model.""" -from typing import Any, Dict, List, Union, Tuple, Type +from typing import Any, Dict, List, Tuple, Type import attr import hydra @@ -13,7 +13,7 @@ import torchmetrics @attr.s -class LitBaseModel(pl.LightningModule): +class BaseLitModel(pl.LightningModule): """Abstract PyTorch Lightning class.""" network: Type[nn.Module] = attr.ib() @@ -30,18 +30,17 @@ class LitBaseModel(pl.LightningModule): val_acc = attr.ib(init=False) test_acc = attr.ib(init=False) - def __attrs_pre_init__(self): + def __attrs_pre_init__(self) -> None: super().__init__() - def __attrs_post_init__(self): - self.loss_fn = self.configure_criterion() + def __attrs_post_init__(self) -> None: + self.loss_fn = self._configure_criterion() # Accuracy metric self.train_acc = torchmetrics.Accuracy() self.val_acc = torchmetrics.Accuracy() self.test_acc = torchmetrics.Accuracy() - @staticmethod def configure_criterion(self) -> Type[nn.Module]: """Returns a loss functions.""" log.info(f"Instantiating criterion <{self.criterion_config._target_}>") |