path: root/text_recognizer/model/transformer.py
"""Lightning model for transformer networks."""
from typing import Callable, Optional, Tuple

import torch
from omegaconf import DictConfig
from torch import Tensor, nn
from torchmetrics import CharErrorRate, WordErrorRate

from text_recognizer.data.tokenizer import Tokenizer
from text_recognizer.decoder.greedy_decoder import GreedyDecoder

from .base import LitBase


class LitTransformer(LitBase):
    """Lightning model for image-to-sequence transformer networks."""

    def __init__(
        self,
        network: nn.Module,
        loss_fn: nn.Module,
        optimizer_config: DictConfig,
        tokenizer: Tokenizer,
        decoder: Callable = GreedyDecoder,
        lr_scheduler_config: Optional[DictConfig] = None,
        max_output_len: int = 682,
    ) -> None:
        super().__init__(
            network,
            loss_fn,
            optimizer_config,
            lr_scheduler_config,
            tokenizer,
        )
        self.max_output_len = max_output_len
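        # Separate metric instances per stage: torchmetrics objects are stateful,
        # so val and test must not share accumulators.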
        self.val_cer = CharErrorRate()
        self.test_cer = CharErrorRate()
        self.val_wer = WordErrorRate()
        self.test_wer = WordErrorRate()
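        # The decoder is expected to be an instantiated callable that maps an
        # image batch to token indices (see predict below).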
        self.decoder = decoder

    def forward(self, data: Tensor) -> Tensor:
        """Autoregressive forward pass."""
        return self.predict(data)

    def teacher_forward(self, data: Tensor, targets: Tensor) -> Tensor:
        """Non-autoregressive forward pass."""
        logits = self.network(data, targets)  # [B, N, C]
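        # nn.CrossEntropyLoss expects the class dimension at index 1, hence the permute.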
        return logits.permute(0, 2, 1)  # [B, C, N]

    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> dict:
        """Training step."""
        data, targets = batch
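        # Teacher forcing: the decoder input drops the final token and the loss
        # target drops the start token, so position i learns to predict token i + 1.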
        logits = self.teacher_forward(data, targets[:, :-1])
        loss = self.loss_fn(logits, targets[:, 1:])

        self.log("train/loss", loss, prog_bar=True)

        outputs = {"loss": loss}

        if self.is_logged_batch():
            preds = self.tokenizer.decode_logits(logits)
            gts = self.tokenizer.batch_decode(targets)
            outputs.update({"predictions": preds, "ground_truths": gts})

        return outputs

    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> dict:
        """Validation step."""
        data, targets = batch
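        # self(data) delegates to predict(), i.e. autoregressive decoding into token indices.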
        preds = self(data)
        preds = self.tokenizer.batch_decode(preds)
        gts = self.tokenizer.batch_decode(targets)

        self.val_cer(preds, gts)
        self.val_wer(preds, gts)

        self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val/wer", self.val_wer, on_step=False, on_epoch=True, prog_bar=True)

        outputs = {}
        self.add_on_first_batch(
            {"predictions": preds, "ground_truths": gts}, outputs, batch_idx
        )
        return outputs

    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> dict:
        """Test step."""
        data, targets = batch
        preds = self(data)
        preds = self.tokenizer.batch_decode(preds)
        gts = self.tokenizer.batch_decode(targets)

        self.test_cer(preds, gts)
        self.test_wer(preds, gts)

        self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
        self.log("test/wer", self.test_wer, on_step=False, on_epoch=True, prog_bar=True)

        outputs = {}
        self.add_on_first_batch(
            {"predictions": preds, "ground_truths": gts}, outputs, batch_idx
        )
        return outputs

    @torch.no_grad()
    def predict(self, x: Tensor) -> Tensor:
        """Decodes an image batch into sequences of token indices."""
        return self.decoder(x)
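
# Example wiring (a minimal sketch; `ConvTransformer`, `optimizer_config`, and the
# GreedyDecoder constructor arguments below are assumptions for illustration, not
# guarantees about this codebase):
#
#     tokenizer = Tokenizer()
#     network = ConvTransformer(...)  # any nn.Module mapping (images, targets) -> [B, N, C] logits
#     loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_index)  # pad_index assumed
#     decoder = GreedyDecoder(network=network, tokenizer=tokenizer, max_output_len=682)
#     model = LitTransformer(network, loss_fn, optimizer_config, tokenizer, decoder=decoder)
#     trainer.fit(model, datamodule=datamodule)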