diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-06 17:42:53 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-06 17:42:53 +0200 |
commit | eb5b206f7e1b08435378d2a02395307be55ee6f1 (patch) | |
tree | 0cd30234afab698eb632b20a7da97e3bc7e98882 /text_recognizer/models/base.py | |
parent | 4d1f2cef39688871d2caafce42a09316381a27ae (diff) |
Refactoring data with attrs and refactor conf for hydra
Diffstat (limited to 'text_recognizer/models/base.py')
-rw-r--r-- | text_recognizer/models/base.py | 5 |
1 files changed, 2 insertions, 3 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 8dc7a36..f95df0f 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -5,7 +5,7 @@ import attr import hydra import loguru.logger as log from omegaconf import DictConfig -import pytorch_lightning as pl +import pytorch_lightning as LightningModule import torch from torch import nn from torch import Tensor @@ -13,7 +13,7 @@ import torchmetrics @attr.s -class BaseLitModel(pl.LightningModule): +class BaseLitModel(LightningModule): """Abstract PyTorch Lightning class.""" network: Type[nn.Module] = attr.ib() @@ -80,7 +80,6 @@ class BaseLitModel(pl.LightningModule): """Configures optimizer and lr scheduler.""" optimizer = self._configure_optimizer() scheduler = self._configure_lr_scheduler(optimizer) - return [optimizer], [scheduler] def forward(self, data: Tensor) -> Tensor: |