summaryrefslogtreecommitdiff
path: root/src/training/trainer/callbacks/progress_bar.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/training/trainer/callbacks/progress_bar.py')
-rw-r--r--src/training/trainer/callbacks/progress_bar.py19
1 files changed, 11 insertions, 8 deletions
diff --git a/src/training/trainer/callbacks/progress_bar.py b/src/training/trainer/callbacks/progress_bar.py
index 1970747..7829fa0 100644
--- a/src/training/trainer/callbacks/progress_bar.py
+++ b/src/training/trainer/callbacks/progress_bar.py
@@ -18,11 +18,11 @@ class ProgressBar(Callback):
def _configure_progress_bar(self) -> None:
"""Configures the tqdm progress bar with custom bar format."""
self.progress_bar = tqdm(
- total=len(self.model.data_loaders["train"]),
- leave=True,
- unit="step",
+ total=len(self.model.train_dataloader()),
+ leave=False,
+ unit="steps",
mininterval=self.log_batch_frequency,
- bar_format="{desc} |{bar:30}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}",
+ bar_format="{desc} |{bar:32}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}",
)
def _key_abbreviations(self, logs: Dict) -> Dict:
@@ -34,13 +34,16 @@ class ProgressBar(Callback):
return {rename(key): value for key, value in logs.items()}
- def on_fit_begin(self) -> None:
- """Creates a tqdm progress bar."""
- self._configure_progress_bar()
+ # def on_fit_begin(self) -> None:
+ # """Creates a tqdm progress bar."""
+ # self._configure_progress_bar()
def on_epoch_begin(self, epoch: int, logs: Optional[Dict]) -> None:
"""Updates the description with the current epoch."""
- self.progress_bar.reset()
+ if epoch == 1:
+ self._configure_progress_bar()
+ else:
+ self.progress_bar.reset()
self.progress_bar.set_description(f"Epoch {epoch}/{self.epochs}")
def on_epoch_end(self, epoch: int, logs: Dict) -> None: