author     Gustaf Rydholm <gustaf.rydholm@gmail.com>   2022-06-01 23:10:12 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>   2022-06-01 23:10:12 +0200
commit     db86cef2d308f58325278061c6aa177a535e7e03 (patch)
tree       a013fa85816337269f9cdc5a8992813fa62d299d /text_recognizer/models/base.py
parent     b980a281712a5b1ee7ee5bd8f5d4762cd91a070b (diff)
Replace attr with attrs
Diffstat (limited to 'text_recognizer/models/base.py')
-rw-r--r--  text_recognizer/models/base.py  24
1 file changed, 10 insertions, 14 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 821cb69..bf3bc08 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -1,7 +1,7 @@
 """Base PyTorch Lightning model."""
 from typing import Any, Dict, List, Optional, Tuple, Type
 
-import attr
+from attrs import define, field
 import hydra
 from loguru import logger as log
 from omegaconf import DictConfig
@@ -14,7 +14,7 @@ import torchmetrics
 from text_recognizer.data.mappings.base import AbstractMapping
 
 
-@attr.s(eq=False)
+@define(eq=False)
 class BaseLitModel(LightningModule):
     """Abstract PyTorch Lightning class."""
 
@@ -22,22 +22,18 @@ class BaseLitModel(LightningModule):
         """Pre init constructor."""
         super().__init__()
 
-    network: Type[nn.Module] = attr.ib()
-    loss_fn: Type[nn.Module] = attr.ib()
-    optimizer_configs: DictConfig = attr.ib()
-    lr_scheduler_configs: Optional[DictConfig] = attr.ib()
-    mapping: Type[AbstractMapping] = attr.ib()
+    network: Type[nn.Module] = field()
+    loss_fn: Type[nn.Module] = field()
+    optimizer_configs: DictConfig = field()
+    lr_scheduler_configs: Optional[DictConfig] = field()
+    mapping: Type[AbstractMapping] = field()
 
     # Placeholders
-    train_acc: torchmetrics.Accuracy = attr.ib(
-        init=False, default=torchmetrics.Accuracy()
-    )
-    val_acc: torchmetrics.Accuracy = attr.ib(
-        init=False, default=torchmetrics.Accuracy()
-    )
-    test_acc: torchmetrics.Accuracy = attr.ib(
+    train_acc: torchmetrics.Accuracy = field(
         init=False, default=torchmetrics.Accuracy()
     )
+    val_acc: torchmetrics.Accuracy = field(init=False, default=torchmetrics.Accuracy())
+    test_acc: torchmetrics.Accuracy = field(init=False, default=torchmetrics.Accuracy())
 
     def optimizer_zero_grad(
         self,
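
For context on the change itself: attrs 20.1 introduced a "next-generation" API under the attrs namespace, in which @define and field() supersede the legacy @attr.s decorator and attr.ib(). Below is a minimal sketch of that migration, independent of this repository; LegacyModel and ModernModel are invented names for illustration only.

# Minimal sketch of the attr -> attrs API migration.
# LegacyModel and ModernModel are hypothetical example classes,
# not code from this repository.
import attr
from attrs import define, field


@attr.s(eq=False)  # legacy API: explicit decorator, attributes via attr.ib()
class LegacyModel:
    name: str = attr.ib()
    steps: int = attr.ib(init=False, default=0)


@define(eq=False)  # modern API: slotted by default, collects annotated fields
class ModernModel:
    name: str = field()
    steps: int = field(init=False, default=0)


legacy = LegacyModel(name="base")
modern = ModernModel(name="base")
# Both spellings produce the same generated __init__ and defaults.
assert (legacy.name, legacy.steps) == (modern.name, modern.steps)

The commit applies the same mechanical substitution to BaseLitModel. Note that eq=False is carried over unchanged: it stops attrs from generating __eq__, leaving Python's identity-based equality and hashing in place, which containers of nn.Module instances generally rely on.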