diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-08 21:43:39 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-08 21:43:39 +0200 |
commit | 82f4acabe24e5171c40afa2939a4777ba87bcc30 (patch) | |
tree | 4d327fa26e4662a0447a66375442a9adeb13ea3d /text_recognizer/models/vqgan.py | |
parent | 240f5e9f20032e82515fa66ce784619527d1041e (diff) |
Add training of VQGAN
Diffstat (limited to 'text_recognizer/models/vqgan.py')
-rw-r--r-- | text_recognizer/models/vqgan.py | 24 |
1 files changed, 2 insertions, 22 deletions
diff --git a/text_recognizer/models/vqgan.py b/text_recognizer/models/vqgan.py index 8ff65cc..80653b6 100644 --- a/text_recognizer/models/vqgan.py +++ b/text_recognizer/models/vqgan.py @@ -9,7 +9,7 @@ from text_recognizer.criterions.vqgan_loss import VQGANLoss @attr.s(auto_attribs=True, eq=False) -class VQVAELitModel(BaseLitModel): +class VQGANLitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" loss_fn: VQGANLoss = attr.ib() @@ -26,7 +26,6 @@ class VQVAELitModel(BaseLitModel): data, _ = batch reconstructions, vq_loss = self(data) - loss = self.loss_fn(reconstructions, data) if optimizer_idx == 0: loss, log = self.loss_fn( @@ -81,14 +80,6 @@ class VQVAELitModel(BaseLitModel): self.log( "val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True ) - self.log( - "val/rec_loss", - log["val/rec_loss"], - prog_bar=True, - logger=True, - on_step=True, - on_epoch=True, - ) self.log_dict(log) _, log = self.loss_fn( @@ -105,24 +96,13 @@ class VQVAELitModel(BaseLitModel): data, _ = batch reconstructions, vq_loss = self(data) - loss, log = self.loss_fn( + _, log = self.loss_fn( data=data, reconstructions=reconstructions, vq_loss=vq_loss, optimizer_idx=0, stage="test", ) - self.log( - "test/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True - ) - self.log( - "test/rec_loss", - log["test/rec_loss"], - prog_bar=True, - logger=True, - on_step=True, - on_epoch=True, - ) self.log_dict(log) _, log = self.loss_fn( |