diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-08 19:59:55 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-08 19:59:55 +0200 |
commit | 240f5e9f20032e82515fa66ce784619527d1041e (patch) | |
tree | b002d28bbfc9abe9b6af090f7db60bea0aeed6e8 /text_recognizer/models | |
parent | d12f70402371dda586d457af2a3df7fb5b3130ad (diff) |
Add VQGAN and loss function
Diffstat (limited to 'text_recognizer/models')
-rw-r--r-- | text_recognizer/models/base.py | 81 | ||||
-rw-r--r-- | text_recognizer/models/transformer.py | 4 | ||||
-rw-r--r-- | text_recognizer/models/vqgan.py | 135 | ||||
-rw-r--r-- | text_recognizer/models/vqvae.py | 12 |
4 files changed, 193 insertions, 39 deletions
# Reconstructed from the patch to text_recognizer/models/base.py.
# Fixes two defects the patch introduced:
#   1. `_configure_optimizer` instantiated `self.optimizer_config` — the old
#      singular attribute this very patch renamed away — instead of the
#      loop-local `optimizer_config`.
#   2. `configure_optimizers` called `self._configure_lr_scheduler`, but the
#      method was renamed `_configure_lr_schedulers`; this raises
#      AttributeError the moment Lightning sets up training.

def _configure_optimizer(self) -> List[Type[torch.optim.Optimizer]]:
    """Instantiate one optimizer per entry in ``self.optimizer_configs``.

    Each config names, via its ``parameters`` key, the attribute on ``self``
    (e.g. ``"network"``) whose parameters that optimizer trains; the key is
    removed before handing the config to Hydra.

    Returns:
        The instantiated optimizers, in config-iteration order.
    """
    optimizers = []
    for optimizer_config in self.optimizer_configs.values():
        # Resolve the sub-module this optimizer is responsible for.
        network = getattr(self, optimizer_config.parameters)
        # NOTE(review): this mutates the shared config in place — calling
        # configure_optimizers a second time would fail on the missing key.
        # Consider operating on a copy of the config instead.
        del optimizer_config.parameters
        log.info(f"Instantiating optimizer <{optimizer_config._target_}>")
        optimizers.append(
            # BUG FIX: was `self.optimizer_config` (stale singular attribute);
            # must instantiate the per-iteration config.
            hydra.utils.instantiate(
                optimizer_config, params=network.parameters()
            )
        )
    return optimizers

def _configure_lr_schedulers(
    self, optimizers: List[Type[torch.optim.Optimizer]]
) -> List[Dict[str, Any]]:
    """Build a Lightning lr-scheduler dict for each (optimizer, config) pair.

    NOTE(review): ``zip`` silently truncates if the number of scheduler
    configs differs from the number of optimizers — confirm the configs are
    meant to pair one-to-one.
    """
    schedulers = []
    for optimizer, lr_scheduler_config in zip(
        optimizers, self.lr_scheduler_configs.values()
    ):
        # Extract non-constructor arguments: these are Lightning scheduler-dict
        # keys, not kwargs for the scheduler class itself.
        monitor = lr_scheduler_config.monitor
        interval = lr_scheduler_config.interval
        del lr_scheduler_config.monitor
        del lr_scheduler_config.interval

        log.info(
            f"Instantiating learning rate scheduler <{lr_scheduler_config._target_}>"
        )
        schedulers.append(
            {
                "monitor": monitor,
                "interval": interval,
                "scheduler": hydra.utils.instantiate(
                    lr_scheduler_config, optimizer=optimizer
                ),
            }
        )
    return schedulers

def configure_optimizers(
    self,
) -> Tuple[List[Type[torch.optim.Optimizer]], List[Dict[str, Any]]]:
    """Configure optimizers and lr schedulers (Lightning hook)."""
    optimizers = self._configure_optimizer()
    # BUG FIX: was `self._configure_lr_scheduler` — the method was renamed
    # `_configure_lr_schedulers` in this same change.
    schedulers = self._configure_lr_schedulers(optimizers)
    return optimizers, schedulers
# Reconstructed from the patch to text_recognizer/models/transformer.py.
# BUG FIX: the patch updated and logged `self.test_acc` under the key
# "val/acc" inside validation_step, mixing validation batches into the test
# metric's running state. Validation must use `self.val_acc`.
# (Presumably `val_acc`/`test_acc` are torchmetrics.Accuracy fields on the
# base model, as `train_acc` visibly is — confirm against BaseLitModel.)

def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
    """Validation step: log character error rate and accuracy for one batch."""
    data, targets = batch
    pred = self(data)
    self.val_cer(pred, targets)
    self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
    # BUG FIX: was self.test_acc — a test-metric state leak during validation.
    self.val_acc(pred, targets)
    self.log("val/acc", self.val_acc, on_step=False, on_epoch=True)

def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
    """Test step: log character error rate and accuracy for one batch."""
    data, targets = batch
    pred = self(data)
    self.test_cer(pred, targets)
    self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
    self.test_acc(pred, targets)
    self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)
# Reconstructed from the patch adding text_recognizer/models/vqgan.py.
"""PyTorch Lightning model for VQGAN training."""
# BUG FIX (docstring): the new module/class docstrings were copy-pasted from
# the transformer model ("for base Transformers") — corrected to VQGAN.
from typing import Tuple

import attr
from torch import Tensor

from text_recognizer.models.base import BaseLitModel
from text_recognizer.criterions.vqgan_loss import VQGANLoss


@attr.s(auto_attribs=True, eq=False)
class VQGANLitModel(BaseLitModel):
    """Lightning model training a VQGAN with alternating generator/discriminator steps.

    BUG FIX (naming): the patch named this class ``VQVAELitModel`` inside
    ``vqgan.py``, colliding with the class of the same name in ``vqvae.py``.
    Renamed; a module-level alias below keeps old references working.
    """

    loss_fn: VQGANLoss = attr.ib()
    # NOTE(review): unused in this class — the weighting appears to live
    # inside VQGANLoss. Confirm and drop if so.
    latent_loss_weight: float = attr.ib(default=0.25)

    def forward(self, data: Tensor) -> Tensor:
        """Forward pass through the VQGAN network."""
        return self.network(data)

    def training_step(
        self, batch: Tuple[Tensor, Tensor], batch_idx: int, optimizer_idx: int
    ) -> Tensor:
        """Training step for one optimizer (0 = generator, 1 = discriminator)."""
        data, _ = batch
        reconstructions, vq_loss = self(data)
        # BUG FIX: removed the dead `loss = self.loss_fn(reconstructions, data)`
        # call the patch made here — its result was discarded and its
        # signature did not match the keyword-only VQGANLoss call below.

        if optimizer_idx not in (0, 1):
            return None

        loss, log = self.loss_fn(
            data=data,
            reconstructions=reconstructions,
            vq_loss=vq_loss,
            optimizer_idx=optimizer_idx,
            stage="train",
        )
        # The two original branches were identical except for the scalar key.
        key = "train/loss" if optimizer_idx == 0 else "train/discriminator_loss"
        self.log(key, loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        self.log_dict(log, prog_bar=False, logger=True, on_step=True, on_epoch=True)
        return loss

    def _eval_step(self, data: Tensor, stage: str) -> None:
        """Shared validation/test logic: log both loss views for *stage*."""
        reconstructions, vq_loss = self(data)

        loss, log = self.loss_fn(
            data=data,
            reconstructions=reconstructions,
            vq_loss=vq_loss,
            optimizer_idx=0,
            stage=stage,
        )
        self.log(
            f"{stage}/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True
        )
        self.log(
            f"{stage}/rec_loss",
            log[f"{stage}/rec_loss"],
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=True,
        )
        self.log_dict(log)

        # Discriminator-side metrics; scalar loss is intentionally discarded.
        _, log = self.loss_fn(
            data=data,
            reconstructions=reconstructions,
            vq_loss=vq_loss,
            optimizer_idx=1,
            stage=stage,
        )
        self.log_dict(log)

    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
        """Validation step."""
        data, _ = batch
        self._eval_step(data, "val")

    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
        """Test step."""
        data, _ = batch
        self._eval_step(data, "test")


# Backward-compatible alias for the name the patch originally exported.
VQVAELitModel = VQGANLitModel
# Reconstructed from the patch to text_recognizer/models/vqvae.py.
# The patch disabled the accuracy metric by commenting the lines out;
# commented-out dead code is removed here instead of being kept.
# NOTE(review): method interiors below are partially outside the visible
# hunks — the loss computation is inferred from the parallel, fully-visible
# training/test hunks of the same patch; confirm against the full file.

def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
    """Validation step: reconstruction loss plus weighted VQ commitment loss."""
    data, _ = batch
    reconstructions, vq_loss = self(data)
    loss = self.loss_fn(reconstructions, data)
    loss = loss + self.latent_loss_weight * vq_loss
    self.log("val/vq_loss", vq_loss)
    self.log("val/loss", loss, prog_bar=True)
    # Accuracy over continuous reconstructions was ill-defined; removed.

def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
    """Test step: reconstruction loss plus weighted VQ commitment loss."""
    data, _ = batch
    reconstructions, vq_loss = self(data)
    loss = self.loss_fn(reconstructions, data)
    loss = loss + self.latent_loss_weight * vq_loss
    self.log("test/vq_loss", vq_loss)
    self.log("test/loss", loss)