diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-06 17:42:53 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-06 17:42:53 +0200 |
commit | eb5b206f7e1b08435378d2a02395307be55ee6f1 (patch) | |
tree | 0cd30234afab698eb632b20a7da97e3bc7e98882 /text_recognizer/models/vqvae.py | |
parent | 4d1f2cef39688871d2caafce42a09316381a27ae (diff) |
Refactoring data with attrs and refactor conf for hydra
Diffstat (limited to 'text_recognizer/models/vqvae.py')
-rw-r--r-- | text_recognizer/models/vqvae.py | 34 |
1 files changed, 4 insertions, 30 deletions
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py index 7dc950f..0172163 100644 --- a/text_recognizer/models/vqvae.py +++ b/text_recognizer/models/vqvae.py @@ -1,49 +1,23 @@ """PyTorch Lightning model for base Transformers.""" from typing import Any, Dict, Union, Tuple, Type +import attr from omegaconf import DictConfig from torch import nn from torch import Tensor import wandb -from text_recognizer.models.base import LitBaseModel +from text_recognizer.models.base import BaseLitModel -class LitVQVAEModel(LitBaseModel): +@attr.s(auto_attribs=True) +class VQVAELitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" - def __init__( - self, - network: Type[nn.Module], - optimizer: Union[DictConfig, Dict], - lr_scheduler: Union[DictConfig, Dict], - criterion: Union[DictConfig, Dict], - monitor: str = "val/loss", - *args: Any, - **kwargs: Dict, - ) -> None: - super().__init__(network, optimizer, lr_scheduler, criterion, monitor) - def forward(self, data: Tensor) -> Tensor: """Forward pass with the transformer network.""" return self.network.predict(data) - def _log_prediction( - self, data: Tensor, reconstructions: Tensor, title: str - ) -> None: - """Logs prediction on image with wandb.""" - try: - self.logger.experiment.log( - { - title: [ - wandb.Image(data[0]), - wandb.Image(reconstructions[0]), - ] - } - ) - except AttributeError: - pass - def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: """Training step.""" data, _ = batch |