diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-05 23:05:25 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-05 23:05:25 +0200 |
commit | 4d1f2cef39688871d2caafce42a09316381a27ae (patch) | |
tree | 0f4385969e7df6d7d313cd5910bde9a7475ca027 /text_recognizer/models | |
parent | f0481decdad9afb52494e9e95996deef843ef233 (diff) |
Refactor with attr, working on cnn+transformer network
Diffstat (limited to 'text_recognizer/models')
-rw-r--r-- | text_recognizer/models/__init__.py | 2 | ||||
-rw-r--r-- | text_recognizer/models/base.py | 11 | ||||
-rw-r--r-- | text_recognizer/models/transformer.py | 30 | ||||
-rw-r--r-- | text_recognizer/models/vqvae.py | 6 |
4 files changed, 27 insertions, 22 deletions
diff --git a/text_recognizer/models/__init__.py b/text_recognizer/models/__init__.py index 5ac2510..1982daf 100644 --- a/text_recognizer/models/__init__.py +++ b/text_recognizer/models/__init__.py @@ -1,3 +1 @@ """PyTorch Lightning models modules.""" -from .transformer import LitTransformerModel -from .vqvae import LitVQVAEModel diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 4e803eb..8dc7a36 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -1,5 +1,5 @@ """Base PyTorch Lightning model.""" -from typing import Any, Dict, List, Union, Tuple, Type +from typing import Any, Dict, List, Tuple, Type import attr import hydra @@ -13,7 +13,7 @@ import torchmetrics @attr.s -class LitBaseModel(pl.LightningModule): +class BaseLitModel(pl.LightningModule): """Abstract PyTorch Lightning class.""" network: Type[nn.Module] = attr.ib() @@ -30,18 +30,17 @@ class LitBaseModel(pl.LightningModule): val_acc = attr.ib(init=False) test_acc = attr.ib(init=False) - def __attrs_pre_init__(self): + def __attrs_pre_init__(self) -> None: super().__init__() - def __attrs_post_init__(self): - self.loss_fn = self.configure_criterion() + def __attrs_post_init__(self) -> None: + self.loss_fn = self._configure_criterion() # Accuracy metric self.train_acc = torchmetrics.Accuracy() self.val_acc = torchmetrics.Accuracy() self.test_acc = torchmetrics.Accuracy() - @staticmethod def configure_criterion(self) -> Type[nn.Module]: """Returns a loss functions.""" log.info(f"Instantiating criterion <{self.criterion_config._target_}>") diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 6be0ac5..ea54d83 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -1,27 +1,35 @@ """PyTorch Lightning model for base Transformers.""" from typing import Dict, List, Optional, Union, Tuple, Type +import attr from omegaconf import DictConfig from torch import nn, Tensor from text_recognizer.data.emnist import emnist_mapping +from text_recognizer.data.mappings import AbstractMapping from text_recognizer.models.metrics import CharacterErrorRate from text_recognizer.models.base import LitBaseModel -class LitTransformerModel(LitBaseModel): +@attr.s +class TransformerLitModel(LitBaseModel): """A PyTorch Lightning model for transformer networks.""" - def __init__( - self, - network: Type[nn.Module], - optimizer: Union[DictConfig, Dict], - lr_scheduler: Union[DictConfig, Dict], - criterion: Union[DictConfig, Dict], - monitor: str = "val_loss", - mapping: Optional[List[str]] = None, - ) -> None: - super().__init__(network, optimizer, lr_scheduler, criterion, monitor) + network: Type[nn.Module] = attr.ib() + criterion_config: DictConfig = attr.ib(converter=DictConfig) + optimizer_config: DictConfig = attr.ib(converter=DictConfig) + lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig) + monitor: str = attr.ib() + mapping: Type[AbstractMapping] = attr.ib() + + def __attrs_post_init__(self) -> None: + super().__init__( + network=self.network, + optimizer_config=self.optimizer_config, + lr_scheduler_config=self.lr_scheduler_config, + criterion_config=self.criterion_config, + monitor=self.monitor, + ) self.mapping, ignore_tokens = self.configure_mapping(mapping) self.val_cer = CharacterErrorRate(ignore_tokens) self.test_cer = CharacterErrorRate(ignore_tokens) diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py index 18e8691..7dc950f 100644 --- a/text_recognizer/models/vqvae.py +++ b/text_recognizer/models/vqvae.py @@ -18,7 +18,7 @@ class LitVQVAEModel(LitBaseModel): optimizer: Union[DictConfig, Dict], lr_scheduler: Union[DictConfig, Dict], criterion: Union[DictConfig, Dict], - monitor: str = "val_loss", + monitor: str = "val/loss", *args: Any, **kwargs: Dict, ) -> None: @@ -50,7 +50,7 @@ class LitVQVAEModel(LitBaseModel): reconstructions, vq_loss = self.network(data) loss = self.loss_fn(reconstructions, data) loss += vq_loss - self.log("train_loss", loss) + self.log("train/loss", loss) return loss def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: @@ -59,7 +59,7 @@ class LitVQVAEModel(LitBaseModel): reconstructions, vq_loss = self.network(data) loss = self.loss_fn(reconstructions, data) loss += vq_loss - self.log("val_loss", loss, prog_bar=True) + self.log("val/loss", loss, prog_bar=True) title = "val_pred_examples" self._log_prediction(data, reconstructions, title) |