1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
|
"""PyTorch Lightning model for base Transformers."""
from typing import Dict, List, Optional, Sequence, Union, Tuple, Type
import attr
import hydra
from omegaconf import DictConfig
from torch import nn, Tensor
from text_recognizer.models.metrics import CharacterErrorRate
from text_recognizer.models.base import BaseLitModel
@attr.s(auto_attribs=True)
class TransformerLitModel(BaseLitModel):
    """A PyTorch Lightning model for transformer networks.

    Trains with teacher forcing (targets shifted by one token along the
    sequence dimension) and tracks character error rate (CER) on the
    validation and test sets, ignoring special tokens.
    """

    # Special tokens excluded from the CER computation.
    ignore_tokens: Sequence[str] = attr.ib(default=("<s>", "<e>", "<p>",))
    val_cer: CharacterErrorRate = attr.ib(init=False)
    test_cer: CharacterErrorRate = attr.ib(init=False)

    def __attrs_post_init__(self) -> None:
        """Instantiate the CER metrics once attrs has populated the fields."""
        self.val_cer = CharacterErrorRate(self.ignore_tokens)
        self.test_cer = CharacterErrorRate(self.ignore_tokens)

    def forward(self, data: Tensor) -> Tensor:
        """Forward pass: autoregressive prediction with the network."""
        return self.network.predict(data)

    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        """Training step with teacher forcing.

        The network is fed the targets without the last token and the loss is
        computed against the targets without the first token — i.e. the
        sequence shifted by one position along dim 1.
        """
        data, targets = batch
        logits = self.network(data, targets[:, :-1])
        loss = self.loss_fn(logits, targets[:, 1:])
        self.log("train/loss", loss)
        return loss

    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
        """Validation step: teacher-forced loss plus CER on free-running predictions."""
        data, targets = batch
        # BUGFIX: slice along the sequence dimension (dim 1), consistent with
        # training_step. The original sliced targets[:-1] / targets[1:], which
        # drops a sample from the batch dimension instead of shifting tokens.
        logits = self.network(data, targets[:, :-1])
        loss = self.loss_fn(logits, targets[:, 1:])
        self.log("val/loss", loss, prog_bar=True)
        pred = self.network.predict(data)
        self.val_cer(pred, targets)
        self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)

    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
        """Test step: CER on free-running predictions only (no loss)."""
        data, targets = batch
        pred = self.network.predict(data)
        self.test_cer(pred, targets)
        self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
|