diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-05 23:39:11 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-05 23:39:11 +0200 |
commit | 65df6a72b002c4b23d6f2eb545839e157f7f2aa0 (patch) | |
tree | d78df1d7143dc9ff9e29afd4fd6bc7490bc79418 /text_recognizer/models/base.py | |
parent | 8bc4b4cab00a2777a748c10fca9b3ee01e32277c (diff) |
Remove attrs
Diffstat (limited to 'text_recognizer/models/base.py')
-rw-r--r-- | text_recognizer/models/base.py | 36 |
1 file changed, 19 insertions, 17 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index bf3bc08..63fe5a7 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -1,7 +1,6 @@ """Base PyTorch Lightning model.""" from typing import Any, Dict, List, Optional, Tuple, Type -from attrs import define, field import hydra from loguru import logger as log from omegaconf import DictConfig @@ -9,31 +8,34 @@ from pytorch_lightning import LightningModule import torch from torch import nn from torch import Tensor -import torchmetrics +from torchmetrics import Accuracy from text_recognizer.data.mappings.base import AbstractMapping -@define(eq=False) class BaseLitModel(LightningModule): """Abstract PyTorch Lightning class.""" - def __attrs_pre_init__(self) -> None: - """Pre init constructor.""" + def __init__( + self, + network: Type[nn.Module], + loss_fn: Type[nn.Module], + optimizer_configs: DictConfig, + lr_scheduler_configs: Optional[DictConfig], + mapping: Type[AbstractMapping], + ) -> None: super().__init__() - network: Type[nn.Module] = field() - loss_fn: Type[nn.Module] = field() - optimizer_configs: DictConfig = field() - lr_scheduler_configs: Optional[DictConfig] = field() - mapping: Type[AbstractMapping] = field() - - # Placeholders - train_acc: torchmetrics.Accuracy = field( - init=False, default=torchmetrics.Accuracy() - ) - val_acc: torchmetrics.Accuracy = field(init=False, default=torchmetrics.Accuracy()) - test_acc: torchmetrics.Accuracy = field(init=False, default=torchmetrics.Accuracy()) + self.network = network + self.loss_fn = loss_fn + self.optimizer_configs = optimizer_configs + self.lr_scheduler_configs = lr_scheduler_configs + self.mapping = mapping + + # Placeholders + self.train_acc = Accuracy() + self.val_acc = Accuracy() + self.test_acc = Accuracy() def optimizer_zero_grad( self, |