diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-08 23:14:23 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-08 23:14:23 +0200 |
commit | e1b504bca41a9793ed7e88ef14f2e2cbd85724f2 (patch) | |
tree | 70b482f890c9ad2be104f0bff8f2172e8411a2be /src/training/trainer/callbacks/progress_bar.py | |
parent | fe23001b6588e6e6e9e2c5a99b72f3445cf5206f (diff) |
IAM datasets implemented.
Diffstat (limited to 'src/training/trainer/callbacks/progress_bar.py')
-rw-r--r-- | src/training/trainer/callbacks/progress_bar.py | 19 |
1 files changed, 11 insertions, 8 deletions
diff --git a/src/training/trainer/callbacks/progress_bar.py b/src/training/trainer/callbacks/progress_bar.py index 1970747..7829fa0 100644 --- a/src/training/trainer/callbacks/progress_bar.py +++ b/src/training/trainer/callbacks/progress_bar.py @@ -18,11 +18,11 @@ class ProgressBar(Callback): def _configure_progress_bar(self) -> None: """Configures the tqdm progress bar with custom bar format.""" self.progress_bar = tqdm( - total=len(self.model.data_loaders["train"]), - leave=True, - unit="step", + total=len(self.model.train_dataloader()), + leave=False, + unit="steps", mininterval=self.log_batch_frequency, - bar_format="{desc} |{bar:30}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}", + bar_format="{desc} |{bar:32}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}", ) def _key_abbreviations(self, logs: Dict) -> Dict: @@ -34,13 +34,16 @@ class ProgressBar(Callback): return {rename(key): value for key, value in logs.items()} - def on_fit_begin(self) -> None: - """Creates a tqdm progress bar.""" - self._configure_progress_bar() + # def on_fit_begin(self) -> None: + # """Creates a tqdm progress bar.""" + # self._configure_progress_bar() def on_epoch_begin(self, epoch: int, logs: Optional[Dict]) -> None: """Updates the description with the current epoch.""" - self.progress_bar.reset() + if epoch == 1: + self._configure_progress_bar() + else: + self.progress_bar.reset() self.progress_bar.set_description(f"Epoch {epoch}/{self.epochs}") def on_epoch_end(self, epoch: int, logs: Dict) -> None: |