summaryrefslogtreecommitdiff
path: root/text_recognizer/models/vqgan.py
blob: 8ff65cc1db0e6c3c8e22b3d96863acf394df6c13 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""PyTorch Lightning model for base Transformers."""
from typing import Tuple

import attr
from torch import Tensor

from text_recognizer.models.base import BaseLitModel
from text_recognizer.criterions.vqgan_loss import VQGANLoss


@attr.s(auto_attribs=True, eq=False)
class VQVAELitModel(BaseLitModel):
    """PyTorch Lightning module for a VQ-VAE/VQGAN trained with a ``VQGANLoss``.

    The wrapped network's forward pass returns ``(reconstructions, vq_loss)``.
    ``loss_fn`` turns those into a scalar loss plus a dict of named metrics. The
    ``optimizer_idx`` passed to it selects which loss is computed; index ``1`` is
    logged as the discriminator loss.
    """

    loss_fn: VQGANLoss = attr.ib()
    # NOTE(review): not read anywhere in this class; kept so existing configs
    # that set it keep working. Confirm whether it is still needed.
    latent_loss_weight: float = attr.ib(default=0.25)

    def forward(self, data: Tensor) -> Tuple[Tensor, Tensor]:
        """Run the wrapped network on ``data``.

        Returns:
            ``(reconstructions, vq_loss)`` as produced by ``self.network``.
        """
        return self.network(data)

    def _compute_loss(
        self,
        data: Tensor,
        reconstructions: Tensor,
        vq_loss: Tensor,
        optimizer_idx: int,
        stage: str,
    ) -> Tuple[Tensor, dict]:
        """Evaluate ``loss_fn`` for one optimizer and return ``(loss, log)``.

        ``stage`` is forwarded to ``loss_fn``. It appears to be used as the key
        prefix of the returned ``log`` dict (e.g. ``"val/rec_loss"``).
        """
        return self.loss_fn(
            data=data,
            reconstructions=reconstructions,
            vq_loss=vq_loss,
            optimizer_idx=optimizer_idx,
            stage=stage,
        )

    def training_step(
        self, batch: Tuple[Tensor, Tensor], batch_idx: int, optimizer_idx: int
    ) -> Tensor:
        """Compute, log, and return the training loss for one optimizer.

        Args:
            batch: ``(data, targets)``. The targets are ignored because the model
                is trained to reconstruct ``data``.
            batch_idx: Index of the batch (unused).
            optimizer_idx: Which optimizer is stepping. ``0`` is logged as
                ``train/loss`` and ``1`` as ``train/discriminator_loss``.

        Returns:
            The loss for ``optimizer_idx``, or ``None`` if it is neither 0 nor 1.
        """
        data, _ = batch
        reconstructions, vq_loss = self(data)

        if optimizer_idx == 0:
            loss_name = "train/loss"
        elif optimizer_idx == 1:
            loss_name = "train/discriminator_loss"
        else:
            return None

        loss, log = self._compute_loss(
            data, reconstructions, vq_loss, optimizer_idx, stage="train"
        )
        self.log(
            loss_name, loss, prog_bar=True, logger=True, on_step=True, on_epoch=True
        )
        self.log_dict(log, prog_bar=False, logger=True, on_step=True, on_epoch=True)
        return loss

    def _evaluation_step(self, batch: Tuple[Tensor, Tensor], stage: str) -> None:
        """Log both optimizers' losses for a validation or test batch.

        Args:
            batch: ``(data, targets)``. The targets are ignored.
            stage: Metric prefix and ``loss_fn`` stage, ``"val"`` or ``"test"``.
        """
        data, _ = batch
        reconstructions, vq_loss = self(data)

        loss, log = self._compute_loss(data, reconstructions, vq_loss, 0, stage)
        rec_key = f"{stage}/rec_loss"
        self.log(
            f"{stage}/loss",
            loss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=True,
        )
        self.log(
            rec_key,
            log[rec_key],
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=True,
        )
        # rec_loss was just logged with progress-bar settings. Logging the same
        # key again in this hook with different arguments is rejected by recent
        # Lightning versions, so leave it out of the bulk log below.
        self.log_dict({key: value for key, value in log.items() if key != rec_key})

        # Also record the discriminator-side metrics for the same batch.
        _, log = self._compute_loss(data, reconstructions, vq_loss, 1, stage)
        self.log_dict(log)

    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
        """Log validation losses for one batch under the ``val/`` prefix."""
        self._evaluation_step(batch, "val")

    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
        """Log test losses for one batch under the ``test/`` prefix."""
        self._evaluation_step(batch, "test")