diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-05 23:05:25 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-05 23:05:25 +0200 |
commit | 4d1f2cef39688871d2caafce42a09316381a27ae (patch) | |
tree | 0f4385969e7df6d7d313cd5910bde9a7475ca027 /text_recognizer | |
parent | f0481decdad9afb52494e9e95996deef843ef233 (diff) |
Refactor with attr, working on cnn+transformer network
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/callbacks/__init__.py | 1 | ||||
-rw-r--r-- | text_recognizer/callbacks/wandb_callbacks.py | 8 | ||||
-rw-r--r-- | text_recognizer/criterions/__init__.py | 1 | ||||
-rw-r--r-- | text_recognizer/criterions/label_smoothing_loss.py (renamed from text_recognizer/networks/loss/label_smoothing_loss.py) | 0 | ||||
-rw-r--r-- | text_recognizer/data/base_data_module.py | 14 | ||||
-rw-r--r-- | text_recognizer/data/base_dataset.py | 24 | ||||
-rw-r--r-- | text_recognizer/models/__init__.py | 2 | ||||
-rw-r--r-- | text_recognizer/models/base.py | 11 | ||||
-rw-r--r-- | text_recognizer/models/transformer.py | 30 | ||||
-rw-r--r-- | text_recognizer/models/vqvae.py | 6 | ||||
-rw-r--r-- | text_recognizer/networks/cnn_tranformer.py | 14 | ||||
-rw-r--r-- | text_recognizer/networks/loss/__init__.py | 2 | ||||
-rw-r--r-- | text_recognizer/networks/util.py | 7 |
13 files changed, 70 insertions, 50 deletions
diff --git a/text_recognizer/callbacks/__init__.py b/text_recognizer/callbacks/__init__.py index e69de29..82d8ce3 100644 --- a/text_recognizer/callbacks/__init__.py +++ b/text_recognizer/callbacks/__init__.py @@ -0,0 +1 @@ +"""Module for PyTorch Lightning callbacks.""" diff --git a/text_recognizer/callbacks/wandb_callbacks.py b/text_recognizer/callbacks/wandb_callbacks.py index 900c3b1..4186b4a 100644 --- a/text_recognizer/callbacks/wandb_callbacks.py +++ b/text_recognizer/callbacks/wandb_callbacks.py @@ -29,7 +29,7 @@ class WatchModel(Callback): log: str = attr.ib(default="gradients") log_freq: int = attr.ib(default=100) - def __attrs_pre_init__(self): + def __attrs_pre_init__(self) -> None: super().__init__() def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: @@ -44,7 +44,7 @@ class UploadCodeAsArtifact(Callback): project_dir: Path = attr.ib(converter=Path) - def __attrs_pre_init__(self): + def __attrs_pre_init__(self) -> None: super().__init__() def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: @@ -65,7 +65,7 @@ class UploadCheckpointAsArtifact(Callback): ckpt_dir: Path = attr.ib(converter=Path) upload_best_only: bool = attr.ib() - def __attrs_pre_init__(self): + def __attrs_pre_init__(self) -> None: super().__init__() def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None: @@ -90,7 +90,7 @@ class LogTextPredictions(Callback): num_samples: int = attr.ib(default=8) ready: bool = attr.ib(default=True) - def __attrs_pre_init__(self): + def __attrs_pre_init__(self) -> None: super().__init__() def on_sanity_check_start( diff --git a/text_recognizer/criterions/__init__.py b/text_recognizer/criterions/__init__.py new file mode 100644 index 0000000..5b0a7ab --- /dev/null +++ b/text_recognizer/criterions/__init__.py @@ -0,0 +1 @@ +"""Module with custom loss functions.""" diff --git a/text_recognizer/networks/loss/label_smoothing_loss.py b/text_recognizer/criterions/label_smoothing_loss.py index 40a7609..40a7609 100644 --- a/text_recognizer/networks/loss/label_smoothing_loss.py +++ b/text_recognizer/criterions/label_smoothing_loss.py diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py index 8b5c188..de5628f 100644 --- a/text_recognizer/data/base_data_module.py +++ b/text_recognizer/data/base_data_module.py @@ -2,7 +2,8 @@ from pathlib import Path from typing import Dict -import pytorch_lightning as pl +import attr +import pytorch_lightning as LightningDataModule from torch.utils.data import DataLoader @@ -14,14 +15,17 @@ def load_and_print_info(data_module_class: type) -> None: print(dataset) -class BaseDataModule(pl.LightningDataModule): +@attr.s +class BaseDataModule(LightningDataModule): """Base PyTorch Lightning DataModule.""" - def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None: + batch_size: int = attr.ib(default=16) + num_workers: int = attr.ib(default=0) + + def __attrs_pre_init__(self) -> None: super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers + def __attrs_post_init__(self) -> None: # Placeholders for subclasses. self.dims = None self.output_dims = None diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py index 8d644d4..4318dfb 100644 --- a/text_recognizer/data/base_dataset.py +++ b/text_recognizer/data/base_dataset.py @@ -1,11 +1,13 @@ """Base PyTorch Dataset class.""" from typing import Any, Callable, Dict, Sequence, Tuple, Union +import attr import torch from torch import Tensor from torch.utils.data import Dataset +@attr.s class BaseDataset(Dataset): """ Base Dataset class that processes data and targets through optional transfroms. @@ -18,19 +20,17 @@ class BaseDataset(Dataset): target transforms. """ - def __init__( - self, - data: Union[Sequence, Tensor], - targets: Union[Sequence, Tensor], - transform: Callable = None, - target_transform: Callable = None, - ) -> None: - if len(data) != len(targets): + data: Union[Sequence, Tensor] = attr.ib() + targets: Union[Sequence, Tensor] = attr.ib() + transform: Callable = attr.ib() + target_transform: Callable = attr.ib() + + def __attrs_pre_init__(self) -> None: + super().__init__() + + def __attrs_post_init__(self) -> None: + if len(self.data) != len(self.targets): raise ValueError("Data and targets must be of equal length.") - self.data = data - self.targets = targets - self.transform = transform - self.target_transform = target_transform def __len__(self) -> int: """Return the length of the dataset.""" diff --git a/text_recognizer/models/__init__.py b/text_recognizer/models/__init__.py index 5ac2510..1982daf 100644 --- a/text_recognizer/models/__init__.py +++ b/text_recognizer/models/__init__.py @@ -1,3 +1 @@ """PyTorch Lightning models modules.""" -from .transformer import LitTransformerModel -from .vqvae import LitVQVAEModel diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 4e803eb..8dc7a36 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -1,5 +1,5 @@ """Base PyTorch Lightning model.""" -from typing import Any, Dict, List, Union, Tuple, Type +from typing import Any, Dict, List, Tuple, Type import attr import hydra @@ -13,7 +13,7 @@ import torchmetrics @attr.s -class LitBaseModel(pl.LightningModule): +class BaseLitModel(pl.LightningModule): """Abstract PyTorch Lightning class.""" network: Type[nn.Module] = attr.ib() @@ -30,18 +30,17 @@ class LitBaseModel(pl.LightningModule): val_acc = attr.ib(init=False) test_acc = attr.ib(init=False) - def __attrs_pre_init__(self): + def __attrs_pre_init__(self) -> None: super().__init__() - def __attrs_post_init__(self): - self.loss_fn = self.configure_criterion() + def __attrs_post_init__(self) -> None: + self.loss_fn = self._configure_criterion() # Accuracy metric self.train_acc = torchmetrics.Accuracy() self.val_acc = torchmetrics.Accuracy() self.test_acc = torchmetrics.Accuracy() - @staticmethod def configure_criterion(self) -> Type[nn.Module]: """Returns a loss functions.""" log.info(f"Instantiating criterion <{self.criterion_config._target_}>") diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 6be0ac5..ea54d83 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -1,27 +1,35 @@ """PyTorch Lightning model for base Transformers.""" from typing import Dict, List, Optional, Union, Tuple, Type +import attr from omegaconf import DictConfig from torch import nn, Tensor from text_recognizer.data.emnist import emnist_mapping +from text_recognizer.data.mappings import AbstractMapping from text_recognizer.models.metrics import CharacterErrorRate from text_recognizer.models.base import LitBaseModel -class LitTransformerModel(LitBaseModel): +@attr.s +class TransformerLitModel(LitBaseModel): """A PyTorch Lightning model for transformer networks.""" - def __init__( - self, - network: Type[nn.Module], - optimizer: Union[DictConfig, Dict], - lr_scheduler: Union[DictConfig, Dict], - criterion: Union[DictConfig, Dict], - monitor: str = "val_loss", - mapping: Optional[List[str]] = None, - ) -> None: - super().__init__(network, optimizer, lr_scheduler, criterion, monitor) + network: Type[nn.Module] = attr.ib() + criterion_config: DictConfig = attr.ib(converter=DictConfig) + optimizer_config: DictConfig = attr.ib(converter=DictConfig) + lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig) + monitor: str = attr.ib() + mapping: Type[AbstractMapping] = attr.ib() + + def __attrs_post_init__(self) -> None: + super().__init__( + network=self.network, + optimizer_config=self.optimizer_config, + lr_scheduler_config=self.lr_scheduler_config, + criterion_config=self.criterion_config, + monitor=self.monitor, + ) self.mapping, ignore_tokens = self.configure_mapping(mapping) self.val_cer = CharacterErrorRate(ignore_tokens) self.test_cer = CharacterErrorRate(ignore_tokens) diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py index 18e8691..7dc950f 100644 --- a/text_recognizer/models/vqvae.py +++ b/text_recognizer/models/vqvae.py @@ -18,7 +18,7 @@ class LitVQVAEModel(LitBaseModel): optimizer: Union[DictConfig, Dict], lr_scheduler: Union[DictConfig, Dict], criterion: Union[DictConfig, Dict], - monitor: str = "val_loss", + monitor: str = "val/loss", *args: Any, **kwargs: Dict, ) -> None: @@ -50,7 +50,7 @@ class LitVQVAEModel(LitBaseModel): reconstructions, vq_loss = self.network(data) loss = self.loss_fn(reconstructions, data) loss += vq_loss - self.log("train_loss", loss) + self.log("train/loss", loss) return loss def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: @@ -59,7 +59,7 @@ class LitVQVAEModel(LitBaseModel): reconstructions, vq_loss = self.network(data) loss = self.loss_fn(reconstructions, data) loss += vq_loss - self.log("val_loss", loss, prog_bar=True) + self.log("val/loss", loss, prog_bar=True) title = "val_pred_examples" self._log_prediction(data, reconstructions, title) diff --git a/text_recognizer/networks/cnn_tranformer.py b/text_recognizer/networks/cnn_tranformer.py new file mode 100644 index 0000000..da69311 --- /dev/null +++ b/text_recognizer/networks/cnn_tranformer.py @@ -0,0 +1,14 @@ +"""Vision transformer for character recognition.""" +from typing import Type + +import attr +from torch import nn, Tensor + + +@attr.s +class CnnTransformer(nn.Module): + def __attrs_pre_init__(self) -> None: + super().__init__() + + backbone: Type[nn.Module] = attr.ib() + head = Type[nn.Module] = attr.ib() diff --git a/text_recognizer/networks/loss/__init__.py b/text_recognizer/networks/loss/__init__.py deleted file mode 100644 index cb83608..0000000 --- a/text_recognizer/networks/loss/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""Loss module.""" -from .loss import LabelSmoothingCrossEntropy diff --git a/text_recognizer/networks/util.py b/text_recognizer/networks/util.py index 05b10a8..109bf4d 100644 --- a/text_recognizer/networks/util.py +++ b/text_recognizer/networks/util.py @@ -1,10 +1,6 @@ """Miscellaneous neural network functionality.""" -import importlib -from pathlib import Path -from typing import Dict, NamedTuple, Union, Type +from typing import Type -from loguru import logger -import torch from torch import nn @@ -19,6 +15,7 @@ def activation_function(activation: str) -> Type[nn.Module]: ["none", nn.Identity()], ["relu", nn.ReLU(inplace=True)], ["selu", nn.SELU(inplace=True)], + ["mish", nn.Mish(inplace=True)], ] ) return activation_fns[activation.lower()] |