diff options
Diffstat (limited to 'text_recognizer/models/base.py')
-rw-r--r-- | text_recognizer/models/base.py | 20 |
1 files changed, 9 insertions, 11 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 8ce5c37..57c5964 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -11,6 +11,8 @@ from torch import nn from torch import Tensor import torchmetrics +from text_recognizer.data.base_mapping import AbstractMapping + @attr.s(eq=False) class BaseLitModel(LightningModule): @@ -20,12 +22,12 @@ class BaseLitModel(LightningModule): super().__init__() network: Type[nn.Module] = attr.ib() - criterion_config: DictConfig = attr.ib(converter=DictConfig) - optimizer_config: DictConfig = attr.ib(converter=DictConfig) - lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig) + mapping: Type[AbstractMapping] = attr.ib() + loss_fn: Type[nn.Module] = attr.ib() + optimizer_config: DictConfig = attr.ib() + lr_scheduler_config: DictConfig = attr.ib() interval: str = attr.ib() monitor: str = attr.ib(default="val/loss") - loss_fn: Type[nn.Module] = attr.ib(init=False) train_acc: torchmetrics.Accuracy = attr.ib( init=False, default=torchmetrics.Accuracy() ) @@ -36,12 +38,6 @@ class BaseLitModel(LightningModule): init=False, default=torchmetrics.Accuracy() ) - @loss_fn.default - def configure_criterion(self) -> Type[nn.Module]: - """Returns a loss functions.""" - log.info(f"Instantiating criterion <{self.criterion_config._target_}>") - return hydra.utils.instantiate(self.criterion_config) - def optimizer_zero_grad( self, epoch: int, @@ -54,7 +50,9 @@ class BaseLitModel(LightningModule): def _configure_optimizer(self) -> Type[torch.optim.Optimizer]: """Configures the optimizer.""" log.info(f"Instantiating optimizer <{self.optimizer_config._target_}>") - return hydra.utils.instantiate(self.optimizer_config, params=self.parameters()) + return hydra.utils.instantiate( + self.optimizer_config, params=self.network.parameters() + ) def _configure_lr_scheduler( self, optimizer: Type[torch.optim.Optimizer] |