1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
|
"""Progress bar callback for the training loop."""
from typing import Dict, Optional
from tqdm import tqdm
from training.trainer.callbacks import Callback
class ProgressBar(Callback):
    """A TQDM progress bar for the training loop.

    A single bar is created lazily at the start of the first epoch of a fit and
    reused (via ``reset``) for every later epoch. The bar's postfix shows the
    training-batch metrics, merged with the most recent end-of-epoch metrics
    once those exist. The bar is closed and discarded in ``on_fit_end``.
    """

    def __init__(self, epochs: int, log_batch_frequency: Optional[int] = None) -> None:
        """Initializes the tqdm callback.

        Args:
            epochs: Total number of epochs; only used in the bar description.
            log_batch_frequency: Forwarded to tqdm as ``mininterval`` (minimum
                time between display refreshes). When ``None``, tqdm's own
                default is used.
        """
        self.epochs = epochs
        self.log_batch_frequency = log_batch_frequency
        self.progress_bar = None
        # Metrics from the latest ``on_epoch_end``; merged into each batch postfix.
        self.val_metrics = {}

    def _configure_progress_bar(self) -> None:
        """Creates the tqdm bar sized to one epoch of training batches."""
        tqdm_kwargs = {}
        if self.log_batch_frequency is not None:
            # Only forward a real value: passing None would replace tqdm's numeric
            # default and break its elapsed-time comparison.
            # NOTE(review): tqdm measures ``mininterval`` in seconds, not batches,
            # despite this parameter's name — confirm the intended meaning.
            tqdm_kwargs["mininterval"] = self.log_batch_frequency
        self.progress_bar = tqdm(
            # Assumes ``self.model`` is attached by the Callback machinery and
            # exposes ``train_dataloader()`` — one bar step per training batch.
            total=len(self.model.train_dataloader()),
            leave=False,
            unit="steps",
            bar_format="{desc} |{bar:32}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}",
            **tqdm_kwargs,
        )

    def _key_abbreviations(self, logs: Dict) -> Dict:
        """Returns a copy of ``logs`` with shortened keys so the bar fits better."""

        def rename(key: str) -> str:
            """Renames accuracy to acc."""
            return key.replace("accuracy", "acc")

        return {rename(key): value for key, value in logs.items()}

    def on_epoch_begin(self, epoch: int, logs: Optional[Dict]) -> None:
        """Starts (or restarts) the bar and labels it with the current epoch."""
        if self.progress_bar is None:
            # First epoch of this fit, whatever its index (0, 1, or a resumed
            # epoch) — create the bar instead of assuming ``epoch == 1``.
            self._configure_progress_bar()
        else:
            # Reuse the existing bar: zero its counter for the new epoch.
            self.progress_bar.reset()
        self.progress_bar.set_description(f"Epoch {epoch}/{self.epochs}")

    def on_epoch_end(self, epoch: int, logs: Optional[Dict]) -> None:
        """Stores the end-of-epoch metrics and shows them in the bar's postfix."""
        # Copy so later mutation of the caller's dict cannot change what is shown.
        self.val_metrics = dict(logs or {})
        # set_postfix refreshes the display by default; an extra update() here
        # would push the counter past ``total`` (e.g. 101/100).
        self.progress_bar.set_postfix(**self._key_abbreviations(self.val_metrics))

    def on_train_batch_end(self, batch: int, logs: Optional[Dict]) -> None:
        """Advances the bar by one training step and shows the latest metrics."""
        # Build a new dict rather than ``logs.update(...)`` so the caller's logs
        # (possibly shared with other callbacks) are not polluted with the
        # epoch-end metrics. Epoch-end values win on key collisions, as before.
        metrics = {**(logs or {}), **self.val_metrics}
        # refresh=False: let update() redraw, so tqdm's mininterval throttle is
        # honoured instead of forcing a redraw on every batch.
        self.progress_bar.set_postfix(refresh=False, **self._key_abbreviations(metrics))
        self.progress_bar.update()

    def on_fit_end(self) -> None:
        """Closes the tqdm progress bar (if one was created) and clears state."""
        if self.progress_bar is not None:
            self.progress_bar.close()
            # Drop the closed bar so a subsequent fit creates a fresh one.
            self.progress_bar = None
        # Epoch metrics from this fit must not leak into the next one.
        self.val_metrics = {}
|