diff options
Diffstat (limited to 'text_recognizer/models/base.py')
-rw-r--r-- | text_recognizer/models/base.py | 36 |
1 files changed, 19 insertions, 17 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index bf3bc08..63fe5a7 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -1,7 +1,6 @@ """Base PyTorch Lightning model.""" from typing import Any, Dict, List, Optional, Tuple, Type -from attrs import define, field import hydra from loguru import logger as log from omegaconf import DictConfig @@ -9,31 +8,34 @@ from pytorch_lightning import LightningModule import torch from torch import nn from torch import Tensor -import torchmetrics +from torchmetrics import Accuracy from text_recognizer.data.mappings.base import AbstractMapping -@define(eq=False) class BaseLitModel(LightningModule): """Abstract PyTorch Lightning class.""" - def __attrs_pre_init__(self) -> None: - """Pre init constructor.""" + def __init__( + self, + network: Type[nn.Module], + loss_fn: Type[nn.Module], + optimizer_configs: DictConfig, + lr_scheduler_configs: Optional[DictConfig], + mapping: Type[AbstractMapping], + ) -> None: super().__init__() - network: Type[nn.Module] = field() - loss_fn: Type[nn.Module] = field() - optimizer_configs: DictConfig = field() - lr_scheduler_configs: Optional[DictConfig] = field() - mapping: Type[AbstractMapping] = field() - - # Placeholders - train_acc: torchmetrics.Accuracy = field( - init=False, default=torchmetrics.Accuracy() - ) - val_acc: torchmetrics.Accuracy = field(init=False, default=torchmetrics.Accuracy()) - test_acc: torchmetrics.Accuracy = field(init=False, default=torchmetrics.Accuracy()) + self.network = network + self.loss_fn = loss_fn + self.optimizer_configs = optimizer_configs + self.lr_scheduler_configs = lr_scheduler_configs + self.mapping = mapping + + # Placeholders + self.train_acc = Accuracy() + self.val_acc = Accuracy() + self.test_acc = Accuracy() def optimizer_zero_grad( self, |