diff options
Diffstat (limited to 'src/training')
-rw-r--r-- | src/training/experiments/iam_line_ctc_experiment.yml | 94 | ||||
-rw-r--r-- | src/training/experiments/line_ctc_experiment.yml | 97 | ||||
-rw-r--r-- | src/training/experiments/sample_experiment.yml | 13 | ||||
-rw-r--r-- | src/training/run_experiment.py | 79 | ||||
-rw-r--r-- | src/training/trainer/callbacks/__init__.py | 14 | ||||
-rw-r--r-- | src/training/trainer/callbacks/lr_schedulers.py | 121 | ||||
-rw-r--r-- | src/training/trainer/callbacks/progress_bar.py | 1 | ||||
-rw-r--r-- | src/training/trainer/callbacks/wandb_callbacks.py | 1 | ||||
-rw-r--r-- | src/training/trainer/train.py | 22 |
9 files changed, 135 insertions, 307 deletions
diff --git a/src/training/experiments/iam_line_ctc_experiment.yml b/src/training/experiments/iam_line_ctc_experiment.yml deleted file mode 100644 index 141c74e..0000000 --- a/src/training/experiments/iam_line_ctc_experiment.yml +++ /dev/null @@ -1,94 +0,0 @@ -experiment_group: Sample Experiments -experiments: - - train_args: - batch_size: 24 - max_epochs: 128 - dataset: - type: IamLinesDataset - args: - subsample_fraction: null - transform: null - target_transform: null - train_args: - num_workers: 6 - train_fraction: 0.85 - model: LineCTCModel - metrics: [cer, wer] - network: - type: LineRecurrentNetwork - args: - # encoder: ResidualNetworkEncoder - # encoder_args: - # in_channels: 1 - # num_classes: 80 - # depths: [2, 2] - # block_sizes: [128, 128] - # activation: SELU - # stn: false - encoder: WideResidualNetwork - encoder_args: - in_channels: 1 - num_classes: 80 - depth: 16 - num_layers: 4 - width_factor: 2 - dropout_rate: 0.2 - activation: selu - use_decoder: false - flatten: true - input_size: 256 - hidden_size: 128 - num_layers: 2 - num_classes: 80 - patch_size: [28, 14] - stride: [1, 5] - criterion: - type: CTCLoss - args: - blank: 79 - optimizer: - type: AdamW - args: - lr: 1.e-03 - betas: [0.9, 0.999] - eps: 1.e-08 - weight_decay: false - amsgrad: false - # lr_scheduler: - # type: OneCycleLR - # args: - # max_lr: 1.e-02 - # epochs: null - # anneal_strategy: linear - lr_scheduler: - type: CosineAnnealingLR - args: - T_max: null - swa_args: - start: 75 - lr: 5.e-2 - callbacks: [Checkpoint, ProgressBar, WandbCallback, WandbImageLogger, SWA] # EarlyStopping, OneCycleLR] - callback_args: - Checkpoint: - monitor: val_loss - mode: min - ProgressBar: - epochs: null - # log_batch_frequency: 100 - # EarlyStopping: - # monitor: val_loss - # min_delta: 0.0 - # patience: 7 - # mode: min - WandbCallback: - log_batch_frequency: 10 - WandbImageLogger: - num_examples: 6 - # OneCycleLR: - # null - SWA: - null - verbosity: 1 # 0, 1, 2 - resume_experiment: null - test: true - test_metric: test_cer diff --git a/src/training/experiments/line_ctc_experiment.yml b/src/training/experiments/line_ctc_experiment.yml index c21c6a2..432d1cc 100644 --- a/src/training/experiments/line_ctc_experiment.yml +++ b/src/training/experiments/line_ctc_experiment.yml @@ -1,55 +1,46 @@ -experiment_group: Sample Experiments +experiment_group: Lines Experiments experiments: - train_args: - batch_size: 64 - max_epochs: 32 + batch_size: 42 + max_epochs: &max_epochs 32 dataset: - type: EmnistLinesDataset + type: IamLinesDataset args: - subsample_fraction: 0.33 - max_length: 34 - min_overlap: 0 - max_overlap: 0.33 - num_samples: 10000 - seed: 4711 - blank: true + subsample_fraction: null + transform: null + target_transform: null train_args: - num_workers: 6 + num_workers: 8 train_fraction: 0.85 model: LineCTCModel metrics: [cer, wer] network: type: LineRecurrentNetwork args: - # encoder: ResidualNetworkEncoder - # encoder_args: - # in_channels: 1 - # num_classes: 81 - # depths: [2, 2] - # block_sizes: [64, 128] - # activation: SELU - # stn: false - encoder: WideResidualNetwork - encoder_args: + backbone: ResidualNetwork + backbone_args: in_channels: 1 - num_classes: 81 - depth: 16 - num_layers: 4 - width_factor: 2 - dropout_rate: 0.2 + num_classes: 64 # Embedding + depths: [2,2] + block_sizes: [32,64] activation: selu - use_decoder: false - flatten: true - input_size: 256 - hidden_size: 128 + stn: false + # encoder: ResidualNetwork + # encoder_args: + # pretrained: training/experiments/CharacterModel_EmnistDataset_ResidualNetwork/0917_203601/model/best.pt + # freeze: false + flatten: false + input_size: 64 + hidden_size: 64 + bidirectional: true num_layers: 2 - num_classes: 81 - patch_size: [28, 14] - stride: [1, 5] + num_classes: 80 + patch_size: [28, 18] + stride: [1, 4] criterion: type: CTCLoss args: - blank: 80 + blank: 79 optimizer: type: AdamW args: @@ -58,40 +49,42 @@ experiments: eps: 1.e-08 weight_decay: 5.e-4 amsgrad: false - # lr_scheduler: - # type: OneCycleLR - # args: - # max_lr: 1.e-03 - # epochs: null - # anneal_strategy: linear lr_scheduler: - type: CosineAnnealingLR + type: OneCycleLR args: - T_max: null + max_lr: 1.e-02 + epochs: *max_epochs + anneal_strategy: cos + pct_start: 0.475 + cycle_momentum: true + base_momentum: 0.85 + max_momentum: 0.9 + div_factor: 10 + final_div_factor: 10000 + interval: step + # lr_scheduler: + # type: CosineAnnealingLR + # args: + # T_max: *max_epochs swa_args: - start: 4 + start: 24 lr: 5.e-2 - callbacks: [Checkpoint, ProgressBar, WandbCallback, WandbImageLogger, SWA] # EarlyStopping, OneCycleLR] + callbacks: [Checkpoint, ProgressBar, WandbCallback, WandbImageLogger] # EarlyStopping] callback_args: Checkpoint: monitor: val_loss mode: min ProgressBar: - epochs: null - log_batch_frequency: 100 + epochs: *max_epochs # EarlyStopping: # monitor: val_loss # min_delta: 0.0 - # patience: 5 + # patience: 10 # mode: min WandbCallback: log_batch_frequency: 10 WandbImageLogger: num_examples: 6 - # OneCycleLR: - # null - SWA: - null verbosity: 1 # 0, 1, 2 resume_experiment: null test: true diff --git a/src/training/experiments/sample_experiment.yml b/src/training/experiments/sample_experiment.yml index 17e220e..8664a15 100644 --- a/src/training/experiments/sample_experiment.yml +++ b/src/training/experiments/sample_experiment.yml @@ -2,7 +2,7 @@ experiment_group: Sample Experiments experiments: - train_args: batch_size: 256 - max_epochs: 32 + max_epochs: &max_epochs 32 dataset: type: EmnistDataset args: @@ -66,16 +66,17 @@ experiments: # type: OneCycleLR # args: # max_lr: 1.e-03 - # epochs: null + # epochs: *max_epochs # anneal_strategy: linear lr_scheduler: type: CosineAnnealingLR args: - T_max: null + T_max: *max_epochs + interval: epoch swa_args: start: 2 lr: 5.e-2 - callbacks: [Checkpoint, ProgressBar, WandbCallback, WandbImageLogger, EarlyStopping, SWA] # OneCycleLR] + callbacks: [Checkpoint, ProgressBar, WandbCallback, WandbImageLogger, EarlyStopping] callback_args: Checkpoint: monitor: val_accuracy @@ -92,10 +93,6 @@ experiments: WandbImageLogger: num_examples: 4 use_transpose: true - # OneCycleLR: - # null - SWA: - null verbosity: 0 # 0, 1, 2 resume_experiment: null test: true diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py index 286b0c6..a347d9f 100644 --- a/src/training/run_experiment.py +++ b/src/training/run_experiment.py @@ -10,6 +10,7 @@ from typing import Callable, Dict, List, Tuple, Type import click from loguru import logger +import numpy as np import torch from tqdm import tqdm from training.gpu_manager import GPUManager @@ -20,11 +21,12 @@ import yaml from text_recognizer.models import Model +from text_recognizer.networks import losses EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments" - +CUSTOM_LOSSES = ["EmbeddingLoss"] DEFAULT_TRAIN_ARGS = {"batch_size": 64, "epochs": 16} @@ -69,21 +71,6 @@ def create_experiment_dir(experiment_config: Dict) -> Path: return experiment_dir, log_dir, model_dir -def check_args(args: Dict, train_args: Dict) -> Dict: - """Checks that the arguments are not None.""" - args = args or {} - - # I just want to set total epochs in train args, instead of changing all parameter. - if "epochs" in args and args["epochs"] is None: - args["epochs"] = train_args["max_epochs"] - - # For CosineAnnealingLR. - if "T_max" in args and args["T_max"] is None: - args["T_max"] = train_args["max_epochs"] - - return args or {} - - def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]: """Loads all modules and arguments.""" # Import the data loader arguments. @@ -115,8 +102,12 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict] network_args = experiment_config["network"].get("args", {}) # Criterion - criterion_ = getattr(torch.nn, experiment_config["criterion"]["type"]) - criterion_args = experiment_config["criterion"].get("args", {}) + if experiment_config["criterion"]["type"] in CUSTOM_LOSSES: + criterion_ = getattr(losses, experiment_config["criterion"]["type"]) + criterion_args = experiment_config["criterion"].get("args", {}) + else: + criterion_ = getattr(torch.nn, experiment_config["criterion"]["type"]) + criterion_args = experiment_config["criterion"].get("args", {}) # Optimizers optimizer_ = getattr(torch.optim, experiment_config["optimizer"]["type"]) @@ -129,13 +120,11 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict] lr_scheduler_ = getattr( torch.optim.lr_scheduler, experiment_config["lr_scheduler"]["type"] ) - lr_scheduler_args = check_args( - experiment_config["lr_scheduler"].get("args", {}), train_args - ) + lr_scheduler_args = experiment_config["lr_scheduler"].get("args", {}) or {} # SWA scheduler. if "swa_args" in experiment_config: - swa_args = check_args(experiment_config.get("swa_args", {}), train_args) + swa_args = experiment_config.get("swa_args", {}) or {} else: swa_args = None @@ -159,19 +148,15 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict] def configure_callbacks(experiment_config: Dict, model_dir: Dict) -> CallbackList: """Configure a callback list for trainer.""" - train_args = experiment_config.get("train_args", {}) - if "Checkpoint" in experiment_config["callback_args"]: experiment_config["callback_args"]["Checkpoint"]["checkpoint_path"] = model_dir - # Callbacks + # Initializes callbacks. callback_modules = importlib.import_module("training.trainer.callbacks") - callbacks = [ - getattr(callback_modules, callback)( - **check_args(experiment_config["callback_args"][callback], train_args) - ) - for callback in experiment_config["callbacks"] - ] + callbacks = [] + for callback in experiment_config["callbacks"]: + args = experiment_config["callback_args"][callback] or {} + callbacks.append(getattr(callback_modules, callback)(**args)) return callbacks @@ -207,11 +192,35 @@ def load_from_checkpoint(model: Type[Model], log_dir: Path, model_dir: Path) -> model.load_checkpoint(checkpoint_path) +def evaluate_embedding(model: Type[Model]) -> Dict: + """Evaluates the embedding space.""" + from pytorch_metric_learning import testers + from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator + + accuracy_calculator = AccuracyCalculator( + include=("mean_average_precision_at_r",), k=10 + ) + + def get_all_embeddings(model: Type[Model]) -> Tuple: + tester = testers.BaseTester() + return tester.get_all_embeddings(model.test_dataset, model.network) + + embeddings, labels = get_all_embeddings(model) + logger.info("Computing embedding accuracy") + accuracies = accuracy_calculator.get_accuracy( + embeddings, embeddings, np.squeeze(labels), np.squeeze(labels), True + ) + logger.info( + f"Test set accuracy (MAP@10) = {accuracies['mean_average_precision_at_r']}" + ) + return accuracies + + def run_experiment( experiment_config: Dict, save_weights: bool, device: str, use_wandb: bool = False ) -> None: """Runs an experiment.""" - logger.info(f"Experiment config: {json.dumps(experiment_config, indent=2)}") + logger.info(f"Experiment config: {json.dumps(experiment_config)}") # Create new experiment. experiment_dir, log_dir, model_dir = create_experiment_dir(experiment_config) @@ -272,7 +281,11 @@ def run_experiment( model.load_from_checkpoint(model_dir / "best.pt") logger.info("Running inference on test set.") - score = trainer.test(model) + if experiment_config["criterion"]["type"] in CUSTOM_LOSSES: + logger.info("Evaluating embedding.") + score = evaluate_embedding(model) + else: + score = trainer.test(model) logger.info(f"Test set evaluation: {score}") diff --git a/src/training/trainer/callbacks/__init__.py b/src/training/trainer/callbacks/__init__.py index c81e4bf..e1bd858 100644 --- a/src/training/trainer/callbacks/__init__.py +++ b/src/training/trainer/callbacks/__init__.py @@ -3,12 +3,7 @@ from .base import Callback, CallbackList from .checkpoint import Checkpoint from .early_stopping import EarlyStopping from .lr_schedulers import ( - CosineAnnealingLR, - CyclicLR, - MultiStepLR, - OneCycleLR, - ReduceLROnPlateau, - StepLR, + LRScheduler, SWA, ) from .progress_bar import ProgressBar @@ -18,15 +13,10 @@ __all__ = [ "Callback", "CallbackList", "Checkpoint", - "CosineAnnealingLR", "EarlyStopping", + "LRScheduler", "WandbCallback", "WandbImageLogger", - "CyclicLR", - "MultiStepLR", - "OneCycleLR", "ProgressBar", - "ReduceLROnPlateau", - "StepLR", "SWA", ] diff --git a/src/training/trainer/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py index bb41d2d..907e292 100644 --- a/src/training/trainer/callbacks/lr_schedulers.py +++ b/src/training/trainer/callbacks/lr_schedulers.py @@ -7,113 +7,27 @@ from training.trainer.callbacks import Callback from text_recognizer.models import Model -class StepLR(Callback): - """Callback for StepLR.""" +class LRScheduler(Callback): + """Generic learning rate scheduler callback.""" def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler - - def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: - """Takes a step at the end of every epoch.""" - self.lr_scheduler.step() - - -class MultiStepLR(Callback): - """Callback for MultiStepLR.""" - - def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler - - def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: - """Takes a step at the end of every epoch.""" - self.lr_scheduler.step() - - -class ReduceLROnPlateau(Callback): - """Callback for ReduceLROnPlateau.""" - - def __init__(self) -> None: - """Initializes the callback.""" super().__init__() - self.lr_scheduler = None def set_model(self, model: Type[Model]) -> None: """Sets the model and lr scheduler.""" self.model = model - self.lr_scheduler = self.model.lr_scheduler + self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"] + self.interval = self.model.lr_scheduler["interval"] def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: """Takes a step at the end of every epoch.""" - val_loss = logs["val_loss"] - self.lr_scheduler.step(val_loss) - - -class CyclicLR(Callback): - """Callback for CyclicLR.""" - - def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler - - def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: - """Takes a step at the end of every training batch.""" - self.lr_scheduler.step() - - -class OneCycleLR(Callback): - """Callback for OneCycleLR.""" - - def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler + if self.interval == "epoch": + self.lr_scheduler.step() def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: """Takes a step at the end of every training batch.""" - self.lr_scheduler.step() - - -class CosineAnnealingLR(Callback): - """Callback for Cosine Annealing.""" - - def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler - - def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: - """Takes a step at the end of every epoch.""" - self.lr_scheduler.step() + if self.interval == "step": + self.lr_scheduler.step() class SWA(Callback): @@ -122,21 +36,32 @@ class SWA(Callback): def __init__(self) -> None: """Initializes the callback.""" super().__init__() + self.lr_scheduler = None + self.interval = None self.swa_scheduler = None + self.swa_start = None + self.current_epoch = 1 def set_model(self, model: Type[Model]) -> None: """Sets the model and lr scheduler.""" self.model = model - self.swa_start = self.model.swa_start - self.swa_scheduler = self.model.lr_scheduler - self.lr_scheduler = self.model.lr_scheduler + self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"] + self.interval = self.model.lr_scheduler["interval"] + self.swa_scheduler = self.model.swa_scheduler["swa_scheduler"] + self.swa_start = self.model.swa_scheduler["swa_start"] def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: """Takes a step at the end of every training batch.""" if epoch > self.swa_start: self.model.swa_network.update_parameters(self.model.network) self.swa_scheduler.step() - else: + elif self.interval == "epoch": + self.lr_scheduler.step() + self.current_epoch = epoch + + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Takes a step at the end of every training batch.""" + if self.current_epoch < self.swa_start and self.interval == "step": self.lr_scheduler.step() def on_fit_end(self) -> None: diff --git a/src/training/trainer/callbacks/progress_bar.py b/src/training/trainer/callbacks/progress_bar.py index 7829fa0..6c4305a 100644 --- a/src/training/trainer/callbacks/progress_bar.py +++ b/src/training/trainer/callbacks/progress_bar.py @@ -11,6 +11,7 @@ class ProgressBar(Callback): def __init__(self, epochs: int, log_batch_frequency: int = None) -> None: """Initializes the tqdm callback.""" self.epochs = epochs + print(epochs, type(epochs)) self.log_batch_frequency = log_batch_frequency self.progress_bar = None self.val_metrics = {} diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py index 6643a44..d2df4d7 100644 --- a/src/training/trainer/callbacks/wandb_callbacks.py +++ b/src/training/trainer/callbacks/wandb_callbacks.py @@ -32,6 +32,7 @@ class WandbCallback(Callback): def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: """Logs training metrics.""" if logs is not None: + logs["lr"] = self.model.optimizer.param_groups[0]["lr"] self._on_batch_end(batch, logs) def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py index b240157..bd6a491 100644 --- a/src/training/trainer/train.py +++ b/src/training/trainer/train.py @@ -9,7 +9,7 @@ import numpy as np import torch from torch import Tensor from torch.optim.swa_utils import update_bn -from training.trainer.callbacks import Callback, CallbackList +from training.trainer.callbacks import Callback, CallbackList, LRScheduler, SWA from training.trainer.util import log_val_metric, RunningAverage import wandb @@ -47,8 +47,14 @@ class Trainer: self.model = None def _configure_callbacks(self) -> None: + """Instantiate the CallbackList.""" if not self.callbacks_configured: - # Instantiate a CallbackList. + # If learning rate schedulers are present, they need to be added to the callbacks. + if self.model.swa_scheduler is not None: + self.callbacks.append(SWA()) + elif self.model.lr_scheduler is not None: + self.callbacks.append(LRScheduler()) + self.callbacks = CallbackList(self.model, self.callbacks) def compute_metrics( @@ -91,7 +97,7 @@ class Trainer: # Forward pass. # Get the network prediction. - output = self.model.network(data) + output = self.model.forward(data) # Compute the loss. loss = self.model.loss_fn(output, targets) @@ -130,7 +136,6 @@ class Trainer: batch: int, samples: Tuple[Tensor, Tensor], loss_avg: Type[RunningAverage], - use_swa: bool = False, ) -> Dict: """Performs the validation step.""" # Pass the tensor to the device for computation. @@ -143,10 +148,7 @@ class Trainer: # Forward pass. # Get the network prediction. # Use SWA if available and using test dataset. - if use_swa and self.model.swa_network is None: - output = self.model.swa_network(data) - else: - output = self.model.network(data) + output = self.model.forward(data) # Compute the loss. loss = self.model.loss_fn(output, targets) @@ -238,7 +240,7 @@ class Trainer: self.model.eval() # Check if SWA network is available. - use_swa = True if self.model.swa_network is not None else False + self.model.use_swa_model() # Running average for the loss. loss_avg = RunningAverage() @@ -247,7 +249,7 @@ class Trainer: summary = [] for batch, samples in enumerate(self.model.test_dataloader()): - metrics = self.validation_step(batch, samples, loss_avg, use_swa) + metrics = self.validation_step(batch, samples, loss_avg) summary.append(metrics) # Compute mean of all test metrics. |