summaryrefslogtreecommitdiff
path: root/src/training/trainer/callbacks/progress_bar.py
blob: 7829fa009dc1e8f04373669662c5cbd2e16176b4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
"""Progress bar callback for the training loop."""
from typing import Dict, Optional

from tqdm import tqdm
from training.trainer.callbacks import Callback


class ProgressBar(Callback):
    """A TQDM progress bar for the training loop.

    A single bar is created on the first epoch callback and reset at the start
    of every later epoch. It advances one step per training batch and displays
    the latest batch logs, merged with the most recent epoch-end logs, as a
    postfix. ``self.model`` is not set in this class; it is presumably attached
    by the trainer / ``Callback`` base before ``on_epoch_begin`` runs.
    """

    def __init__(self, epochs: int, log_batch_frequency: Optional[int] = None) -> None:
        """Initializes the tqdm callback.

        Args:
            epochs: Total number of epochs, shown in the bar description.
            log_batch_frequency: Forwarded to tqdm as ``mininterval``, i.e. the
                minimum number of seconds between display refreshes. ``None``
                keeps tqdm's own default.
        """
        self.epochs = epochs
        self.log_batch_frequency = log_batch_frequency
        self.progress_bar = None
        # Logs from the most recent epoch end; shown alongside every batch.
        self.val_metrics = {}

    def _configure_progress_bar(self) -> None:
        """Creates the tqdm bar, sized to one epoch of training batches."""
        options = {}
        if self.log_batch_frequency is not None:
            # tqdm compares elapsed time against ``mininterval`` on every
            # update, so it must never be forwarded as None.
            options["mininterval"] = self.log_batch_frequency
        self.progress_bar = tqdm(
            total=len(self.model.train_dataloader()),
            leave=False,
            unit="steps",
            bar_format="{desc} |{bar:32}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}",
            **options,
        )

    def _key_abbreviations(self, logs: Dict) -> Dict:
        """Shortens log keys ("accuracy" -> "acc") so the progress bar fits better."""
        return {key.replace("accuracy", "acc"): value for key, value in logs.items()}

    def _show_metrics(self, logs: Dict, refresh: bool) -> None:
        """Displays ``logs`` (with abbreviated keys) as the bar's postfix.

        Args:
            logs: Metric name -> value mapping to display.
            refresh: Whether to redraw the bar immediately.
        """
        # Passed as a mapping rather than **kwargs so that log keys such as
        # "refresh" cannot collide with set_postfix's own parameters.
        self.progress_bar.set_postfix(self._key_abbreviations(logs), refresh=refresh)

    def on_epoch_begin(self, epoch: int, logs: Optional[Dict]) -> None:
        """Creates the bar on first use (otherwise resets it) and shows the epoch."""
        if self.progress_bar is None:
            # Created lazily so this works whatever the first epoch number is
            # (e.g. when training resumes from a later epoch).
            self._configure_progress_bar()
        else:
            self.progress_bar.reset()
        self.progress_bar.set_description(f"Epoch {epoch}/{self.epochs}")

    def on_epoch_end(self, epoch: int, logs: Dict) -> None:
        """Shows the epoch-end metrics and keeps them for subsequent batch postfixes."""
        # Copied so that later mutation of the caller's dict does not change
        # what this bar displays.
        self.val_metrics = dict(logs or {})
        # Every step is already counted in on_train_batch_end; calling update()
        # here would push the counter past ``total``. Refreshing is enough.
        self._show_metrics(self.val_metrics, refresh=True)

    def on_train_batch_end(self, batch: int, logs: Dict) -> None:
        """Advances the bar one step, showing batch logs plus the last epoch metrics."""
        # Build a new dict instead of updating ``logs`` in place, so the
        # caller's batch logs are not polluted with epoch-end metrics. On key
        # clashes the epoch-end value wins.
        postfix = {**(logs or {}), **self.val_metrics}
        # Defer the redraw to update(), which honours tqdm's mininterval
        # throttling instead of forcing a terminal write every batch.
        self._show_metrics(postfix, refresh=False)
        self.progress_bar.update()

    def on_fit_end(self) -> None:
        """Closes the tqdm progress bar, if one was created, and clears state."""
        if self.progress_bar is not None:
            self.progress_bar.close()
            # Cleared so a subsequent fit starts with a fresh bar.
            self.progress_bar = None
        self.val_metrics = {}