diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-06 17:42:53 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-06 17:42:53 +0200 |
commit | eb5b206f7e1b08435378d2a02395307be55ee6f1 (patch) | |
tree | 0cd30234afab698eb632b20a7da97e3bc7e98882 /text_recognizer/models | |
parent | 4d1f2cef39688871d2caafce42a09316381a27ae (diff) |
Refactoring data with attrs and refactor conf for hydra
Diffstat (limited to 'text_recognizer/models')
-rw-r--r-- | text_recognizer/models/base.py | 5 | ||||
-rw-r--r-- | text_recognizer/models/metrics.py | 15 | ||||
-rw-r--r-- | text_recognizer/models/transformer.py | 26 | ||||
-rw-r--r-- | text_recognizer/models/vqvae.py | 34 |
4 files changed, 24 insertions, 56 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 8dc7a36..f95df0f 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -5,7 +5,7 @@ import attr import hydra import loguru.logger as log from omegaconf import DictConfig -import pytorch_lightning as pl +import pytorch_lightning as LightningModule import torch from torch import nn from torch import Tensor @@ -13,7 +13,7 @@ import torchmetrics @attr.s -class BaseLitModel(pl.LightningModule): +class BaseLitModel(LightningModule): """Abstract PyTorch Lightning class.""" network: Type[nn.Module] = attr.ib() @@ -80,7 +80,6 @@ class BaseLitModel(pl.LightningModule): """Configures optimizer and lr scheduler.""" optimizer = self._configure_optimizer() scheduler = self._configure_lr_scheduler(optimizer) - return [optimizer], [scheduler] def forward(self, data: Tensor) -> Tensor: diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py index 58d0537..4117ae2 100644 --- a/text_recognizer/models/metrics.py +++ b/text_recognizer/models/metrics.py @@ -1,18 +1,23 @@ """Character Error Rate (CER).""" -from typing import Sequence +from typing import Set, Sequence +import attr import editdistance import torch from torch import Tensor -import torchmetrics +from torchmetrics import Metric -class CharacterErrorRate(torchmetrics.Metric): +@attr.s +class CharacterErrorRate(Metric): """Character error rate metric, computed using Levenshtein distance.""" - def __init__(self, ignore_tokens: Sequence[int], *args) -> None: + ignore_tokens: Set = attr.ib(converter=set) + error: Tensor = attr.ib(init=False) + total: Tensor = attr.ib(init=False) + + def __attrs_post_init__(self) -> None: super().__init__() - self.ignore_tokens = set(ignore_tokens) self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index ea54d83..8c9fe8a 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -2,35 +2,24 @@ from typing import Dict, List, Optional, Union, Tuple, Type import attr +import hydra from omegaconf import DictConfig from torch import nn, Tensor from text_recognizer.data.emnist import emnist_mapping from text_recognizer.data.mappings import AbstractMapping from text_recognizer.models.metrics import CharacterErrorRate -from text_recognizer.models.base import LitBaseModel +from text_recognizer.models.base import BaseLitModel -@attr.s -class TransformerLitModel(LitBaseModel): +@attr.s(auto_attribs=True) +class TransformerLitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" - network: Type[nn.Module] = attr.ib() - criterion_config: DictConfig = attr.ib(converter=DictConfig) - optimizer_config: DictConfig = attr.ib(converter=DictConfig) - lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig) - monitor: str = attr.ib() - mapping: Type[AbstractMapping] = attr.ib() + mapping_config: DictConfig = attr.ib(converter=DictConfig) def __attrs_post_init__(self) -> None: - super().__init__( - network=self.network, - optimizer_config=self.optimizer_config, - lr_scheduler_config=self.lr_scheduler_config, - criterion_config=self.criterion_config, - monitor=self.monitor, - ) - self.mapping, ignore_tokens = self.configure_mapping(mapping) + self.mapping, ignore_tokens = self._configure_mapping() self.val_cer = CharacterErrorRate(ignore_tokens) self.test_cer = CharacterErrorRate(ignore_tokens) @@ -39,9 +28,10 @@ class TransformerLitModel(LitBaseModel): return self.network.predict(data) @staticmethod - def configure_mapping(mapping: Optional[List[str]]) -> Tuple[List[str], List[int]]: + def _configure_mapping() -> Tuple[Type[AbstractMapping], List[int]]: """Configure mapping.""" # TODO: Fix me!!! + # Load config with hydra mapping, inverse_mapping, _ = emnist_mapping(["\n"]) start_index = inverse_mapping["<s>"] end_index = inverse_mapping["<e>"] diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py index 7dc950f..0172163 100644 --- a/text_recognizer/models/vqvae.py +++ b/text_recognizer/models/vqvae.py @@ -1,49 +1,23 @@ """PyTorch Lightning model for base Transformers.""" from typing import Any, Dict, Union, Tuple, Type +import attr from omegaconf import DictConfig from torch import nn from torch import Tensor import wandb -from text_recognizer.models.base import LitBaseModel +from text_recognizer.models.base import BaseLitModel -class LitVQVAEModel(LitBaseModel): +@attr.s(auto_attribs=True) +class VQVAELitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" - def __init__( - self, - network: Type[nn.Module], - optimizer: Union[DictConfig, Dict], - lr_scheduler: Union[DictConfig, Dict], - criterion: Union[DictConfig, Dict], - monitor: str = "val/loss", - *args: Any, - **kwargs: Dict, - ) -> None: - super().__init__(network, optimizer, lr_scheduler, criterion, monitor) - def forward(self, data: Tensor) -> Tensor: """Forward pass with the transformer network.""" return self.network.predict(data) - def _log_prediction( - self, data: Tensor, reconstructions: Tensor, title: str - ) -> None: - """Logs prediction on image with wandb.""" - try: - self.logger.experiment.log( - { - title: [ - wandb.Image(data[0]), - wandb.Image(reconstructions[0]), - ] - } - ) - except AttributeError: - pass - def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: """Training step.""" data, _ = batch |