summaryrefslogtreecommitdiff
path: root/text_recognizer/models/vqvae.py
blob: 22da01819ac75e992133285b399443074e9c5a33 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
"""PyTorch Lightning model for base Transformers."""
from typing import Any, Dict, Union, Tuple, Type

import attr
from omegaconf import DictConfig
from torch import nn
from torch import Tensor
import wandb

from text_recognizer.models.base import BaseLitModel


@attr.s(auto_attribs=True, eq=False)
class VQVAELitModel(BaseLitModel):
    """A PyTorch Lightning model for VQ-VAE networks.

    Calling ``self.network(data)`` must return a pair
    ``(reconstructions, vq_loss)``. ``self.network.predict`` is used for
    inference in :meth:`forward`. The optimized objective is
    ``self.loss_fn(reconstructions, data) + vq_loss``.

    ``self.network`` and ``self.loss_fn`` are presumably set up by
    ``BaseLitModel`` (not visible here) — NOTE(review): confirm.
    """

    def forward(self, data: Tensor) -> Tensor:
        """Run inference with the VQ-VAE network via ``network.predict``."""
        return self.network.predict(data)

    def _vqvae_loss(self, batch: Tuple[Tensor, Tensor]) -> Tensor:
        """Return reconstruction loss plus VQ loss for one batch.

        The second element of ``batch`` (targets) is ignored: the model is
        trained to reconstruct its own input.
        """
        data, _ = batch
        reconstructions, vq_loss = self.network(data)
        # Out-of-place addition: modifying the loss_fn output in place could
        # break autograd if that tensor is saved for the backward pass.
        return self.loss_fn(reconstructions, data) + vq_loss

    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        """Training step: log and return the combined loss."""
        loss = self._vqvae_loss(batch)
        self.log("train/loss", loss)
        return loss

    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
        """Validation step: log the combined loss to the progress bar."""
        loss = self._vqvae_loss(batch)
        self.log("val/loss", loss, prog_bar=True)

    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
        """Test step: log the combined loss."""
        loss = self._vqvae_loss(batch)
        self.log("test/loss", loss)