summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/conf/datamodule/iam_extended_paragraphs.yaml5
-rw-r--r--training/conf/dataset/iam_extended_paragraphs.yaml6
-rw-r--r--training/conf/lr_scheduler/one_cycle.yaml23
-rw-r--r--training/conf/model/lit_vqvae.yaml2
-rw-r--r--training/run.py1
-rw-r--r--training/utils.py2
6 files changed, 23 insertions, 16 deletions
diff --git a/training/conf/datamodule/iam_extended_paragraphs.yaml b/training/conf/datamodule/iam_extended_paragraphs.yaml
new file mode 100644
index 0000000..3070b56
--- /dev/null
+++ b/training/conf/datamodule/iam_extended_paragraphs.yaml
@@ -0,0 +1,5 @@
+_target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs
+batch_size: 32
+num_workers: 12
+train_fraction: 0.8
+augment: true
diff --git a/training/conf/dataset/iam_extended_paragraphs.yaml b/training/conf/dataset/iam_extended_paragraphs.yaml
deleted file mode 100644
index 6439a15..0000000
--- a/training/conf/dataset/iam_extended_paragraphs.yaml
+++ /dev/null
@@ -1,6 +0,0 @@
-type: IAMExtendedParagraphs
-args:
- batch_size: 32
- num_workers: 12
- train_fraction: 0.8
- augment: true
diff --git a/training/conf/lr_scheduler/one_cycle.yaml b/training/conf/lr_scheduler/one_cycle.yaml
index 60a6f27..e8cb5c4 100644
--- a/training/conf/lr_scheduler/one_cycle.yaml
+++ b/training/conf/lr_scheduler/one_cycle.yaml
@@ -1,8 +1,15 @@
-type: OneCycleLR
-args:
- interval: step
- max_lr: 1.0e-3
- three_phase: true
- epochs: 64
- steps_per_epoch: 633 # num_samples / batch_size
-monitor: val_loss
+_target_: torch.optim.lr_scheduler.OneCycleLR
+max_lr: 1.0e-3
+total_steps: None
+epochs: None
+steps_per_epoch: None
+pct_start: 0.3
+anneal_strategy: 'cos'
+cycle_momentum: True
+base_momentum: 0.85
+max_momentum: 0.95
+div_factor: 25.0
+final_div_factor: 10000.0
+three_phase: true
+last_epoch: -1
+verbose: false
diff --git a/training/conf/model/lit_vqvae.yaml b/training/conf/model/lit_vqvae.yaml
index 7136dbd..6be37e5 100644
--- a/training/conf/model/lit_vqvae.yaml
+++ b/training/conf/model/lit_vqvae.yaml
@@ -1,3 +1,3 @@
-type: LitVQVAEModel
+_target_: text_recognizer.models.vqvae.VQVAELitModel
args:
mapping: sentence_piece
diff --git a/training/run.py b/training/run.py
index 5f7c927..31da666 100644
--- a/training/run.py
+++ b/training/run.py
@@ -51,6 +51,7 @@ def run(config: DictConfig) -> Optional[float]:
)
# Log hyperparameters
+ log.info("Logging hyperparameters")
utils.log_hyperparameters(config=config, model=model, trainer=trainer)
if config.debug:
diff --git a/training/utils.py b/training/utils.py
index 4c31dc3..140d97e 100644
--- a/training/utils.py
+++ b/training/utils.py
@@ -1,4 +1,4 @@
-"""Util functions for training hydra configs and pytorch lightning."""
+"""Util functions for training with hydra and pytorch lightning."""
from typing import Any, List, Type
import warnings