summaryrefslogtreecommitdiff
path: root/text_recognizer/training
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-08 23:38:03 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-08 23:38:03 +0200
commite388cd95c77d37a51324cff9d84a809421bf97d3 (patch)
treed585545f85d03ea8a6907daba254821fddeb1589 /text_recognizer/training
parentf4629a0d4149d5870c9fd8ce83ff5d391bd7ddd3 (diff)
Bug fixes word pieces
Diffstat (limited to 'text_recognizer/training')
-rw-r--r--text_recognizer/training/experiments/image_transformer.yaml72
-rw-r--r--text_recognizer/training/run_experiment.py201
2 files changed, 273 insertions, 0 deletions
diff --git a/text_recognizer/training/experiments/image_transformer.yaml b/text_recognizer/training/experiments/image_transformer.yaml
new file mode 100644
index 0000000..bedcbb5
--- /dev/null
+++ b/text_recognizer/training/experiments/image_transformer.yaml
@@ -0,0 +1,72 @@
+seed: 4711
+
+network:
+ desc: null
+ type: ImageTransformer
+ args:
+ encoder:
+ type: null
+ args: null
+ num_decoder_layers: 4
+ hidden_dim: 256
+ num_heads: 4
+ expansion_dim: 1024
+ dropout_rate: 0.1
+ transformer_activation: glu
+
+model:
+ desc: null
+ type: LitTransformerModel
+ args:
+ optimizer:
+ type: MADGRAD
+ args:
+ lr: 1.0e-2
+ momentum: 0.9
+ weight_decay: 0
+ eps: 1.0e-6
+ lr_scheduler:
+ type: CosineAnnealingLR
+ args:
+ T_max: 512
+ criterion:
+ type: CrossEntropyLoss
+ args:
+ weight: None
+ ignore_index: -100
+ reduction: mean
+ monitor: val_loss
+ mapping: sentence_piece
+
+data:
+ desc: null
+ type: IAMExtendedParagraphs
+ args:
+ batch_size: 16
+ num_workers: 12
+ train_fraction: 0.8
+ augment: true
+
+callbacks:
+ - type: ModelCheckpoint
+ args:
+ monitor: val_loss
+ mode: min
+ - type: EarlyStopping
+ args:
+ monitor: val_loss
+ mode: min
+ patience: 10
+
+trainer:
+ desc: null
+ args:
+ stochastic_weight_avg: true
+ auto_scale_batch_size: binsearch
+ gradient_clip_val: 0
+ fast_dev_run: false
+ gpus: 1
+ precision: 16
+ max_epochs: 512
+ terminate_on_nan: true
+ weights_summary: true
diff --git a/text_recognizer/training/run_experiment.py b/text_recognizer/training/run_experiment.py
new file mode 100644
index 0000000..ed1a947
--- /dev/null
+++ b/text_recognizer/training/run_experiment.py
@@ -0,0 +1,201 @@
+"""Script to run experiments."""
+from datetime import datetime
+import importlib
+from pathlib import Path
+from typing import Dict, List, Optional, Type
+
+import click
+from loguru import logger
+from omegaconf import DictConfig, OmegaConf
+import pytorch_lightning as pl
+import torch
+from torch import nn
+from torchsummary import summary
+from tqdm import tqdm
+import wandb
+
+
+SEED = 4711
+EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments"
+
+
+def _configure_logging(log_dir: Optional[Path], verbose: int = 0) -> None:
+ """Configure the loguru logger for output to terminal and disk."""
+
+ def _get_level(verbose: int) -> str:
+ """Sets the logger level."""
+ levels = {0: "WARNING", 1: "INFO", 2: "DEBUG"}
+ verbose = min(verbose, 2)
+ return levels[verbose]
+
+ # Remove default logger to get tqdm to work properly.
+ logger.remove()
+
+ # Fetch verbosity level.
+ level = _get_level(verbose)
+
+ logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level=level)
+ if log_dir is not None:
+ logger.add(
+ str(log_dir / "train.log"),
+ format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}",
+ )
+
+
+def _load_config(file_path: Path) -> DictConfig:
+ """Return experiment config."""
+ logger.info(f"Loading config from: {file_path}")
+ if not file_path.exists():
+ raise FileNotFoundError(f"Experiment config not found at: {file_path}")
+ return OmegaConf.load(file_path)
+
+
+def _import_class(module_and_class_name: str) -> type:
+ """Import class from module."""
+ module_name, class_name = module_and_class_name.rsplit(".", 1)
+ module = importlib.import_module(module_name)
+ return getattr(module, class_name)
+
+
+def _configure_callbacks(
+ callbacks: List[DictConfig],
+) -> List[Type[pl.callbacks.Callback]]:
+ """Configures lightning callbacks."""
+ pl_callbacks = [
+ getattr(pl.callbacks, callback.type)(**callback.args) for callback in callbacks
+ ]
+ return pl_callbacks
+
+
+def _configure_logger(
+ network: Type[nn.Module], args: Dict, use_wandb: bool
+) -> Type[pl.loggers.LightningLoggerBase]:
+ """Configures lightning logger."""
+ if use_wandb:
+ pl_logger = pl.loggers.WandbLogger()
+ pl_logger.watch(network)
+ pl_logger.log_hyperparams(vars(args))
+ return pl_logger
+ return pl.logger.TensorBoardLogger("training/logs")
+
+
+def _save_best_weights(
+ callbacks: List[Type[pl.callbacks.Callback]], use_wandb: bool
+) -> None:
+ """Saves the best model."""
+ model_checkpoint_callback = next(
+ callback
+ for callback in callbacks
+ if isinstance(callback, pl.callbacks.ModelCheckpoint)
+ )
+ best_model_path = model_checkpoint_callback.best_model_path
+ if best_model_path:
+ logger.info(f"Best model saved at: {best_model_path}")
+ if use_wandb:
+ logger.info("Uploading model to W&B...")
+ wandb.save(best_model_path)
+
+
+def _load_lit_model(
+ lit_model_class: type, network: Type[nn.Module], config: DictConfig
+) -> Type[pl.LightningModule]:
+ """Load lightning model."""
+ if config.load_checkpoint is not None:
+ logger.info(
+ f"Loading network weights from checkpoint: {config.load_checkpoint}"
+ )
+ return lit_model_class.load_from_checkpoint(
+ config.load_checkpoint, network=network, **config.model.args
+ )
+ return lit_model_class(network=network, **config.model.args)
+
+
+def run(
+ filename: str,
+ train: bool,
+ test: bool,
+ tune: bool,
+ use_wandb: bool,
+ verbose: int = 0,
+) -> None:
+ """Runs experiment."""
+
+ _configure_logging(None, verbose=verbose)
+ logger.info("Starting experiment...")
+
+ # Seed everything in the experiment.
+ logger.info(f"Seeding everthing with seed={SEED}")
+ pl.utilities.seed.seed_everything(SEED)
+
+ # Load config.
+ file_path = EXPERIMENTS_DIRNAME / filename
+ config = _load_config(file_path)
+
+ # Load classes.
+ data_module_class = _import_class(f"text_recognizer.data.{config.data.type}")
+ network_class = _import_class(f"text_recognizer.networks.{config.network.type}")
+ lit_model_class = _import_class(f"text_recognizer.models.{config.model.type}")
+
+ # Initialize data object and network.
+ data_module = data_module_class(**config.data.args)
+ network = network_class(**data_module.config(), **config.network.args)
+
+ # Load callback and logger.
+ callbacks = _configure_callbacks(config.callbacks)
+ pl_logger = _configure_logger(network, config, use_wandb)
+
+ # Load ligtning model.
+ lit_model = _load_lit_model(lit_model_class, network, config)
+
+ trainer = pl.Trainer(
+ **config.trainer.args,
+ callbacks=callbacks,
+ logger=pl_logger,
+ weigths_save_path="training/logs",
+ )
+
+ if tune:
+ logger.info(f"Tuning learning rate and batch size...")
+ trainer.tune(lit_model, datamodule=data_module)
+
+ if train:
+ logger.info(f"Training network...")
+ trainer.fit(lit_model, datamodule=data_module)
+
+ if test:
+ logger.info(f"Testing network...")
+ trainer.test(lit_model, datamodule=data_module)
+
+ _save_best_weights(callbacks, use_wandb)
+
+
+@click.command()
+@click.option("-f", "--experiment_config", type=str, help="Path to experiment config.")
+@click.option("--use_wandb", is_flag=True, help="If true, do use wandb for logging.")
+@click.option(
+ "--tune", is_flag=True, help="If true, tune hyperparameters for training."
+)
+@click.option("--train", is_flag=True, help="If true, train the model.")
+@click.option("--test", is_flag=True, help="If true, test the model.")
+@click.option("-v", "--verbose", count=True)
+def cli(
+ experiment_config: str,
+ use_wandb: bool,
+ tune: bool,
+ train: bool,
+ test: bool,
+ verbose: int,
+) -> None:
+ """Run experiment."""
+ run(
+ filename=experiment_config,
+ train=train,
+ test=test,
+ tune=tune,
+ use_wandb=use_wandb,
+ verbose=verbose,
+ )
+
+
+if __name__ == "__main__":
+ cli()