diff options
Diffstat (limited to 'training')
19 files changed, 92 insertions, 84 deletions
diff --git a/training/conf/callbacks/wandb/image_reconstructions.yaml b/training/__init__.py index e69de29..e69de29 100644 --- a/training/conf/callbacks/wandb/image_reconstructions.yaml +++ b/training/__init__.py diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py index 6379cc0..906531f 100644 --- a/training/callbacks/wandb_callbacks.py +++ b/training/callbacks/wandb_callbacks.py @@ -1,11 +1,10 @@ """Weights and Biases callbacks.""" from pathlib import Path -from typing import List -import attr import wandb from pytorch_lightning import Callback, LightningModule, Trainer from pytorch_lightning.loggers import LoggerCollection, WandbLogger +from pytorch_lightning.utilities import rank_zero_only def get_wandb_logger(trainer: Trainer) -> WandbLogger: @@ -22,31 +21,27 @@ def get_wandb_logger(trainer: Trainer) -> WandbLogger: raise Exception("Weight and Biases logger not found for some reason...") -@attr.s class WatchModel(Callback): """Make W&B watch the model at the beginning of the run.""" - log: str = attr.ib(default="gradients") - log_freq: int = attr.ib(default=100) - - def __attrs_pre_init__(self) -> None: - super().__init__() + def __init__(self, log: str = "gradients", log_freq: int = 100) -> None: + self.log = log + self.log_freq = log_freq + @rank_zero_only def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: """Watches model weights with wandb.""" logger = get_wandb_logger(trainer) logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq) -@attr.s class UploadCodeAsArtifact(Callback): """Upload all *.py files to W&B as an artifact, at the beginning of the run.""" - project_dir: Path = attr.ib(converter=Path) - - def __attrs_pre_init__(self) -> None: - super().__init__() + def __init__(self, project_dir: str) -> None: + self.project_dir = Path(project_dir) + @rank_zero_only def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: """Uploads project code as an artifact.""" logger = get_wandb_logger(trainer) @@ -58,16 +53,16 @@ class UploadCodeAsArtifact(Callback): experiment.use_artifact(artifact) -@attr.s -class UploadCheckpointAsArtifact(Callback): +class UploadCheckpointsAsArtifact(Callback): """Upload checkpoint to wandb as an artifact, at the end of a run.""" - ckpt_dir: Path = attr.ib(converter=Path) - upload_best_only: bool = attr.ib() - - def __attrs_pre_init__(self) -> None: - super().__init__() + def __init__( + self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False + ) -> None: + self.ckpt_dir = ckpt_dir + self.upload_best_only = upload_best_only + @rank_zero_only def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None: """Uploads model checkpoint to W&B.""" logger = get_wandb_logger(trainer) @@ -83,15 +78,12 @@ class UploadCheckpointAsArtifact(Callback): experiment.use_artifact(ckpts) -@attr.s class LogTextPredictions(Callback): """Logs a validation batch with image to text transcription.""" - num_samples: int = attr.ib(default=8) - ready: bool = attr.ib(default=True) - - def __attrs_pre_init__(self) -> None: - super().__init__() + def __init__(self, num_samples: int = 8) -> None: + self.num_samples = num_samples + self.ready = False def _log_predictions( self, stage: str, trainer: Trainer, pl_module: LightningModule @@ -111,20 +103,20 @@ class LogTextPredictions(Callback): logits = pl_module(imgs) mapping = pl_module.mapping + columns = ["id", "image", "prediction", "truth"] + data = [ + [id, wandb.Image(img), mapping.get_text(pred), mapping.get_text(label)] + for id, (img, pred, label) in enumerate( + zip( + imgs[: self.num_samples], + logits[: self.num_samples], + labels[: self.num_samples], + ) + ) + ] + experiment.log( - { - f"OCR/{experiment.name}/{stage}": [ - wandb.Image( - img, - caption=f"Pred: {mapping.get_text(pred)}, Label: {mapping.get_text(label)}", - ) - for img, pred, label in zip( - imgs[: self.num_samples], - logits[: self.num_samples], - labels[: self.num_samples], - ) - ] - } + {f"OCR/{experiment.name}/{stage}": wandb.Table(data=data, columns=columns)} ) def on_sanity_check_start( @@ -143,20 +135,17 @@ class LogTextPredictions(Callback): """Logs predictions on validation epoch end.""" self._log_predictions(stage="val", trainer=trainer, pl_module=pl_module) - def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: """Logs predictions on train epoch end.""" self._log_predictions(stage="test", trainer=trainer, pl_module=pl_module) -@attr.s class LogReconstuctedImages(Callback): """Log reconstructions of images.""" - num_samples: int = attr.ib(default=8) - ready: bool = attr.ib(default=True) - - def __attrs_pre_init__(self) -> None: - super().__init__() + def __init__(self, num_samples: int = 8) -> None: + self.num_samples = num_samples + self.ready = False def _log_reconstruction( self, stage: str, trainer: Trainer, pl_module: LightningModule @@ -202,6 +191,6 @@ class LogReconstuctedImages(Callback): """Logs predictions on validation epoch end.""" self._log_reconstruction(stage="val", trainer=trainer, pl_module=pl_module) - def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: """Logs predictions on train epoch end.""" self._log_reconstruction(stage="test", trainer=trainer, pl_module=pl_module) diff --git a/training/conf/callbacks/checkpoint.yaml b/training/conf/callbacks/checkpoint.yaml index db34cb1..b4101d8 100644 --- a/training/conf/callbacks/checkpoint.yaml +++ b/training/conf/callbacks/checkpoint.yaml @@ -6,4 +6,4 @@ model_checkpoint: mode: min # can be "max" or "min" verbose: false dirpath: checkpoints/ - filename: {epoch:02d} + filename: "{epoch:02d}" diff --git a/training/conf/callbacks/wandb/checkpoints.yaml b/training/conf/callbacks/wandb_checkpoints.yaml index a4a16ff..a4a16ff 100644 --- a/training/conf/callbacks/wandb/checkpoints.yaml +++ b/training/conf/callbacks/wandb_checkpoints.yaml diff --git a/training/conf/callbacks/wandb/code.yaml b/training/conf/callbacks/wandb_code.yaml index 35f6ea3..35f6ea3 100644 --- a/training/conf/callbacks/wandb/code.yaml +++ b/training/conf/callbacks/wandb_code.yaml diff --git a/training/conf/callbacks/wandb_image_reconstructions.yaml b/training/conf/callbacks/wandb_image_reconstructions.yaml new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/training/conf/callbacks/wandb_image_reconstructions.yaml diff --git a/training/conf/callbacks/wandb_ocr.yaml b/training/conf/callbacks/wandb_ocr.yaml index efa3dda..9c9a6da 100644 --- a/training/conf/callbacks/wandb_ocr.yaml +++ b/training/conf/callbacks/wandb_ocr.yaml @@ -1,6 +1,6 @@ defaults: - default - - wandb/watch - - wandb/code - - wandb/checkpoints - - wandb/ocr_predictions + - wandb_watch + - wandb_code + - wandb_checkpoints + - wandb_ocr_predictions diff --git a/training/conf/callbacks/wandb/ocr_predictions.yaml b/training/conf/callbacks/wandb_ocr_predictions.yaml index 573fa96..573fa96 100644 --- a/training/conf/callbacks/wandb/ocr_predictions.yaml +++ b/training/conf/callbacks/wandb_ocr_predictions.yaml diff --git a/training/conf/callbacks/wandb/watch.yaml b/training/conf/callbacks/wandb_watch.yaml index 511608c..511608c 100644 --- a/training/conf/callbacks/wandb/watch.yaml +++ b/training/conf/callbacks/wandb_watch.yaml diff --git a/training/conf/config.yaml b/training/conf/config.yaml index 93215ed..782bcbb 100644 --- a/training/conf/config.yaml +++ b/training/conf/config.yaml @@ -1,8 +1,9 @@ defaults: - callbacks: wandb_ocr - criterion: label_smoothing - - dataset: iam_extended_paragraphs + - datamodule: iam_extended_paragraphs - hydra: default + - logger: wandb - lr_scheduler: one_cycle - mapping: word_piece - model: lit_transformer @@ -15,3 +16,21 @@ tune: false train: true test: true logging: INFO + +# path to original working directory +# hydra hijacks working directory by changing it to the current log directory, +# so it's useful to have this path as a special variable +# learn more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory +work_dir: ${hydra:runtime.cwd} + +# use `python run.py debug=true` for easy debugging! +# this will run 1 train, val and test loop with only 1 batch +# equivalent to running `python run.py trainer.fast_dev_run=true` +# (this is placed here just for easier access from command line) +debug: False + +# pretty print config at the start of the run using Rich library +print_config: True + +# disable python warnings if they annoy you +ignore_warnings: True diff --git a/training/conf/criterion/label_smoothing.yaml b/training/conf/criterion/label_smoothing.yaml index 13daba8..684b5bb 100644 --- a/training/conf/criterion/label_smoothing.yaml +++ b/training/conf/criterion/label_smoothing.yaml @@ -1,4 +1,3 @@ -_target_: text_recognizer.criterion.label_smoothing.LabelSmoothingLoss -label_smoothing: 0.1 -vocab_size: 1006 +_target_: text_recognizer.criterions.label_smoothing.LabelSmoothingLoss +smoothing: 0.1 ignore_index: 1002 diff --git a/training/conf/datamodule/iam_extended_paragraphs.yaml b/training/conf/datamodule/iam_extended_paragraphs.yaml index 3070b56..2d1a03e 100644 --- a/training/conf/datamodule/iam_extended_paragraphs.yaml +++ b/training/conf/datamodule/iam_extended_paragraphs.yaml @@ -1,5 +1,6 @@ _target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs -batch_size: 32 +batch_size: 4 num_workers: 12 train_fraction: 0.8 augment: true +pin_memory: false diff --git a/training/conf/lr_scheduler/one_cycle.yaml b/training/conf/lr_scheduler/one_cycle.yaml index 5afdf81..eecee8a 100644 --- a/training/conf/lr_scheduler/one_cycle.yaml +++ b/training/conf/lr_scheduler/one_cycle.yaml @@ -1,8 +1,8 @@ _target_: torch.optim.lr_scheduler.OneCycleLR max_lr: 1.0e-3 total_steps: null -epochs: null -steps_per_epoch: null +epochs: 512 +steps_per_epoch: 4992 pct_start: 0.3 anneal_strategy: cos cycle_momentum: true diff --git a/training/conf/mapping/word_piece.yaml b/training/conf/mapping/word_piece.yaml index 3792523..48384f5 100644 --- a/training/conf/mapping/word_piece.yaml +++ b/training/conf/mapping/word_piece.yaml @@ -1,4 +1,4 @@ -_target_: text_recognizer.data.mappings.WordPieceMapping +_target_: text_recognizer.data.word_piece_mapping.WordPieceMapping num_features: 1000 tokens: iamdb_1kwp_tokens_1000.txt lexicon: iamdb_1kwp_lex_1000.txt @@ -6,4 +6,4 @@ data_dir: null use_words: false prepend_wordsep: false special_tokens: [ <s>, <e>, <p> ] -extra_symbols: [ \n ] +extra_symbols: [ "\n" ] diff --git a/training/conf/model/lit_transformer.yaml b/training/conf/model/lit_transformer.yaml index 6ffde4e..c190151 100644 --- a/training/conf/model/lit_transformer.yaml +++ b/training/conf/model/lit_transformer.yaml @@ -1,7 +1,7 @@ _target_: text_recognizer.models.transformer.TransformerLitModel interval: step monitor: val/loss -ignore_tokens: [ <s>, <e>, <p> ] +max_output_len: 451 start_token: <s> end_token: <e> pad_token: <p> diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml index a97157d..f76e892 100644 --- a/training/conf/network/conv_transformer.yaml +++ b/training/conf/network/conv_transformer.yaml @@ -6,6 +6,5 @@ _target_: text_recognizer.networks.conv_transformer.ConvTransformer input_dims: [1, 576, 640] hidden_dim: 96 dropout_rate: 0.2 -max_output_len: 451 num_classes: 1006 pad_index: 1002 diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml index 90b9d8a..eb80f64 100644 --- a/training/conf/network/decoder/transformer_decoder.yaml +++ b/training/conf/network/decoder/transformer_decoder.yaml @@ -18,3 +18,4 @@ ff_kwargs: dropout_rate: 0.2 cross_attend: true pre_norm: true +rotary_emb: null diff --git a/training/run.py b/training/run.py index 30479c6..13a6a82 100644 --- a/training/run.py +++ b/training/run.py @@ -12,35 +12,40 @@ from pytorch_lightning import ( Trainer, ) from pytorch_lightning.loggers import LightningLoggerBase -from text_recognizer.data.mappings import AbstractMapping from torch import nn +from text_recognizer.data.base_mapping import AbstractMapping import utils def run(config: DictConfig) -> Optional[float]: """Runs experiment.""" - utils.configure_logging(config.logging) + utils.configure_logging(config) log.info("Starting experiment...") if config.get("seed"): - seed_everything(config.seed) + seed_everything(config.seed, workers=True) log.info(f"Instantiating mapping <{config.mapping._target_}>") mapping: AbstractMapping = hydra.utils.instantiate(config.mapping) log.info(f"Instantiating datamodule <{config.datamodule._target_}>") - datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule, mapping=mapping) + datamodule: LightningDataModule = hydra.utils.instantiate( + config.datamodule, mapping=mapping + ) log.info(f"Instantiating network <{config.network._target_}>") network: nn.Module = hydra.utils.instantiate(config.network) + log.info(f"Instantiating criterion <{config.criterion._target_}>") + loss_fn: Type[nn.Module] = hydra.utils.instantiate(config.criterion) + log.info(f"Instantiating model <{config.model._target_}>") model: LightningModule = hydra.utils.instantiate( - **config.model, + config.model, mapping=mapping, network=network, - criterion_config=config.criterion, + loss_fn=loss_fn, optimizer_config=config.optimizer, lr_scheduler_config=config.lr_scheduler, _recursive_=False, @@ -77,4 +82,4 @@ def run(config: DictConfig) -> Optional[float]: trainer.test(model, datamodule=datamodule) log.info(f"Best checkpoint path:\n{trainer.checkpoint_callback.best_model_path}") - utils.finish(trainer) + utils.finish(logger) diff --git a/training/utils.py b/training/utils.py index ef74f61..d23396e 100644 --- a/training/utils.py +++ b/training/utils.py @@ -17,6 +17,10 @@ from tqdm import tqdm import wandb +def print_config(config: DictConfig) -> None: + print(OmegaConf.to_yaml(config)) + + @rank_zero_only def configure_logging(config: DictConfig) -> None: """Configure the loguru logger for output to terminal and disk.""" @@ -30,7 +34,7 @@ def configure_callbacks(config: DictConfig,) -> List[Type[Callback]]: callbacks = [] if config.get("callbacks"): for callback_config in config.callbacks.values(): - if config.get("_target_"): + if callback_config.get("_target_"): log.info(f"Instantiating callback <{callback_config._target_}>") callbacks.append(hydra.utils.instantiate(callback_config)) return callbacks @@ -41,8 +45,8 @@ def configure_logger(config: DictConfig) -> List[Type[LightningLoggerBase]]: logger = [] if config.get("logger"): for logger_config in config.logger.values(): - if config.get("_target_"): - log.info(f"Instantiating callback <{logger_config._target_}>") + if logger_config.get("_target_"): + log.info(f"Instantiating logger <{logger_config._target_}>") logger.append(hydra.utils.instantiate(logger_config)) return logger @@ -67,17 +71,8 @@ def extras(config: DictConfig) -> None: # Debuggers do not like GPUs and multiprocessing. if config.trainer.get("gpus"): config.trainer.gpus = 0 - if config.datamodule.get("pin_memory"): - config.datamodule.pin_memory = False - if config.datamodule.get("num_workers"): - config.datamodule.num_workers = 0 - - # Force multi-gpu friendly config. - accelerator = config.trainer.get("accelerator") - if accelerator in ["ddp", "ddp_spawn", "dp", "ddp2"]: - log.info( - f"Forcing ddp friendly configuration! <config.trainer.accelerator={accelerator}>" - ) + if config.trainer.get("precision"): + config.trainer.precision = 32 if config.datamodule.get("pin_memory"): config.datamodule.pin_memory = False if config.datamodule.get("num_workers"): |