diff options
Diffstat (limited to 'text_recognizer/models/base.py')
-rw-r--r-- | text_recognizer/models/base.py | 24 |
1 files changed, 10 insertions, 14 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 821cb69..bf3bc08 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -1,7 +1,7 @@ """Base PyTorch Lightning model.""" from typing import Any, Dict, List, Optional, Tuple, Type -import attr +from attrs import define, field import hydra from loguru import logger as log from omegaconf import DictConfig @@ -14,7 +14,7 @@ import torchmetrics from text_recognizer.data.mappings.base import AbstractMapping -@attr.s(eq=False) +@define(eq=False) class BaseLitModel(LightningModule): """Abstract PyTorch Lightning class.""" @@ -22,22 +22,18 @@ class BaseLitModel(LightningModule): """Pre init constructor.""" super().__init__() - network: Type[nn.Module] = attr.ib() - loss_fn: Type[nn.Module] = attr.ib() - optimizer_configs: DictConfig = attr.ib() - lr_scheduler_configs: Optional[DictConfig] = attr.ib() - mapping: Type[AbstractMapping] = attr.ib() + network: Type[nn.Module] = field() + loss_fn: Type[nn.Module] = field() + optimizer_configs: DictConfig = field() + lr_scheduler_configs: Optional[DictConfig] = field() + mapping: Type[AbstractMapping] = field() # Placeholders - train_acc: torchmetrics.Accuracy = attr.ib( - init=False, default=torchmetrics.Accuracy() - ) - val_acc: torchmetrics.Accuracy = attr.ib( - init=False, default=torchmetrics.Accuracy() - ) - test_acc: torchmetrics.Accuracy = attr.ib( + train_acc: torchmetrics.Accuracy = field( init=False, default=torchmetrics.Accuracy() ) + val_acc: torchmetrics.Accuracy = field(init=False, default=torchmetrics.Accuracy()) + test_acc: torchmetrics.Accuracy = field(init=False, default=torchmetrics.Accuracy()) def optimizer_zero_grad( self, |