summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/__init__.py (renamed from training/conf/callbacks/wandb/image_reconstructions.yaml)0
-rw-r--r--training/callbacks/wandb_callbacks.py83
-rw-r--r--training/conf/callbacks/checkpoint.yaml2
-rw-r--r--training/conf/callbacks/wandb_checkpoints.yaml (renamed from training/conf/callbacks/wandb/checkpoints.yaml)0
-rw-r--r--training/conf/callbacks/wandb_code.yaml (renamed from training/conf/callbacks/wandb/code.yaml)0
-rw-r--r--training/conf/callbacks/wandb_image_reconstructions.yaml0
-rw-r--r--training/conf/callbacks/wandb_ocr.yaml8
-rw-r--r--training/conf/callbacks/wandb_ocr_predictions.yaml (renamed from training/conf/callbacks/wandb/ocr_predictions.yaml)0
-rw-r--r--training/conf/callbacks/wandb_watch.yaml (renamed from training/conf/callbacks/wandb/watch.yaml)0
-rw-r--r--training/conf/config.yaml21
-rw-r--r--training/conf/criterion/label_smoothing.yaml5
-rw-r--r--training/conf/datamodule/iam_extended_paragraphs.yaml3
-rw-r--r--training/conf/lr_scheduler/one_cycle.yaml4
-rw-r--r--training/conf/mapping/word_piece.yaml4
-rw-r--r--training/conf/model/lit_transformer.yaml2
-rw-r--r--training/conf/network/conv_transformer.yaml1
-rw-r--r--training/conf/network/decoder/transformer_decoder.yaml1
-rw-r--r--training/run.py19
-rw-r--r--training/utils.py23
19 files changed, 92 insertions, 84 deletions
diff --git a/training/conf/callbacks/wandb/image_reconstructions.yaml b/training/__init__.py
index e69de29..e69de29 100644
--- a/training/conf/callbacks/wandb/image_reconstructions.yaml
+++ b/training/__init__.py
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py
index 6379cc0..906531f 100644
--- a/training/callbacks/wandb_callbacks.py
+++ b/training/callbacks/wandb_callbacks.py
@@ -1,11 +1,10 @@
"""Weights and Biases callbacks."""
from pathlib import Path
-from typing import List
-import attr
import wandb
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.loggers import LoggerCollection, WandbLogger
+from pytorch_lightning.utilities import rank_zero_only
def get_wandb_logger(trainer: Trainer) -> WandbLogger:
@@ -22,31 +21,27 @@ def get_wandb_logger(trainer: Trainer) -> WandbLogger:
raise Exception("Weight and Biases logger not found for some reason...")
-@attr.s
class WatchModel(Callback):
"""Make W&B watch the model at the beginning of the run."""
- log: str = attr.ib(default="gradients")
- log_freq: int = attr.ib(default=100)
-
- def __attrs_pre_init__(self) -> None:
- super().__init__()
+ def __init__(self, log: str = "gradients", log_freq: int = 100) -> None:
+ self.log = log
+ self.log_freq = log_freq
+ @rank_zero_only
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Watches model weights with wandb."""
logger = get_wandb_logger(trainer)
logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq)
-@attr.s
class UploadCodeAsArtifact(Callback):
"""Upload all *.py files to W&B as an artifact, at the beginning of the run."""
- project_dir: Path = attr.ib(converter=Path)
-
- def __attrs_pre_init__(self) -> None:
- super().__init__()
+ def __init__(self, project_dir: str) -> None:
+ self.project_dir = Path(project_dir)
+ @rank_zero_only
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Uploads project code as an artifact."""
logger = get_wandb_logger(trainer)
@@ -58,16 +53,16 @@ class UploadCodeAsArtifact(Callback):
experiment.use_artifact(artifact)
-@attr.s
-class UploadCheckpointAsArtifact(Callback):
+class UploadCheckpointsAsArtifact(Callback):
"""Upload checkpoint to wandb as an artifact, at the end of a run."""
- ckpt_dir: Path = attr.ib(converter=Path)
- upload_best_only: bool = attr.ib()
-
- def __attrs_pre_init__(self) -> None:
- super().__init__()
+ def __init__(
+ self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False
+ ) -> None:
+ self.ckpt_dir = ckpt_dir
+ self.upload_best_only = upload_best_only
+ @rank_zero_only
def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Uploads model checkpoint to W&B."""
logger = get_wandb_logger(trainer)
@@ -83,15 +78,12 @@ class UploadCheckpointAsArtifact(Callback):
experiment.use_artifact(ckpts)
-@attr.s
class LogTextPredictions(Callback):
"""Logs a validation batch with image to text transcription."""
- num_samples: int = attr.ib(default=8)
- ready: bool = attr.ib(default=True)
-
- def __attrs_pre_init__(self) -> None:
- super().__init__()
+ def __init__(self, num_samples: int = 8) -> None:
+ self.num_samples = num_samples
+ self.ready = False
def _log_predictions(
self, stage: str, trainer: Trainer, pl_module: LightningModule
@@ -111,20 +103,20 @@ class LogTextPredictions(Callback):
logits = pl_module(imgs)
mapping = pl_module.mapping
+ columns = ["id", "image", "prediction", "truth"]
+ data = [
+ [id, wandb.Image(img), mapping.get_text(pred), mapping.get_text(label)]
+ for id, (img, pred, label) in enumerate(
+ zip(
+ imgs[: self.num_samples],
+ logits[: self.num_samples],
+ labels[: self.num_samples],
+ )
+ )
+ ]
+
experiment.log(
- {
- f"OCR/{experiment.name}/{stage}": [
- wandb.Image(
- img,
- caption=f"Pred: {mapping.get_text(pred)}, Label: {mapping.get_text(label)}",
- )
- for img, pred, label in zip(
- imgs[: self.num_samples],
- logits[: self.num_samples],
- labels[: self.num_samples],
- )
- ]
- }
+ {f"OCR/{experiment.name}/{stage}": wandb.Table(data=data, columns=columns)}
)
def on_sanity_check_start(
@@ -143,20 +135,17 @@ class LogTextPredictions(Callback):
"""Logs predictions on validation epoch end."""
self._log_predictions(stage="val", trainer=trainer, pl_module=pl_module)
- def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+ def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Logs predictions on train epoch end."""
self._log_predictions(stage="test", trainer=trainer, pl_module=pl_module)
-@attr.s
class LogReconstuctedImages(Callback):
"""Log reconstructions of images."""
- num_samples: int = attr.ib(default=8)
- ready: bool = attr.ib(default=True)
-
- def __attrs_pre_init__(self) -> None:
- super().__init__()
+ def __init__(self, num_samples: int = 8) -> None:
+ self.num_samples = num_samples
+ self.ready = False
def _log_reconstruction(
self, stage: str, trainer: Trainer, pl_module: LightningModule
@@ -202,6 +191,6 @@ class LogReconstuctedImages(Callback):
"""Logs predictions on validation epoch end."""
self._log_reconstruction(stage="val", trainer=trainer, pl_module=pl_module)
- def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+ def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Logs predictions on train epoch end."""
self._log_reconstruction(stage="test", trainer=trainer, pl_module=pl_module)
diff --git a/training/conf/callbacks/checkpoint.yaml b/training/conf/callbacks/checkpoint.yaml
index db34cb1..b4101d8 100644
--- a/training/conf/callbacks/checkpoint.yaml
+++ b/training/conf/callbacks/checkpoint.yaml
@@ -6,4 +6,4 @@ model_checkpoint:
mode: min # can be "max" or "min"
verbose: false
dirpath: checkpoints/
- filename: {epoch:02d}
+ filename: "{epoch:02d}"
diff --git a/training/conf/callbacks/wandb/checkpoints.yaml b/training/conf/callbacks/wandb_checkpoints.yaml
index a4a16ff..a4a16ff 100644
--- a/training/conf/callbacks/wandb/checkpoints.yaml
+++ b/training/conf/callbacks/wandb_checkpoints.yaml
diff --git a/training/conf/callbacks/wandb/code.yaml b/training/conf/callbacks/wandb_code.yaml
index 35f6ea3..35f6ea3 100644
--- a/training/conf/callbacks/wandb/code.yaml
+++ b/training/conf/callbacks/wandb_code.yaml
diff --git a/training/conf/callbacks/wandb_image_reconstructions.yaml b/training/conf/callbacks/wandb_image_reconstructions.yaml
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/training/conf/callbacks/wandb_image_reconstructions.yaml
diff --git a/training/conf/callbacks/wandb_ocr.yaml b/training/conf/callbacks/wandb_ocr.yaml
index efa3dda..9c9a6da 100644
--- a/training/conf/callbacks/wandb_ocr.yaml
+++ b/training/conf/callbacks/wandb_ocr.yaml
@@ -1,6 +1,6 @@
defaults:
- default
- - wandb/watch
- - wandb/code
- - wandb/checkpoints
- - wandb/ocr_predictions
+ - wandb_watch
+ - wandb_code
+ - wandb_checkpoints
+ - wandb_ocr_predictions
diff --git a/training/conf/callbacks/wandb/ocr_predictions.yaml b/training/conf/callbacks/wandb_ocr_predictions.yaml
index 573fa96..573fa96 100644
--- a/training/conf/callbacks/wandb/ocr_predictions.yaml
+++ b/training/conf/callbacks/wandb_ocr_predictions.yaml
diff --git a/training/conf/callbacks/wandb/watch.yaml b/training/conf/callbacks/wandb_watch.yaml
index 511608c..511608c 100644
--- a/training/conf/callbacks/wandb/watch.yaml
+++ b/training/conf/callbacks/wandb_watch.yaml
diff --git a/training/conf/config.yaml b/training/conf/config.yaml
index 93215ed..782bcbb 100644
--- a/training/conf/config.yaml
+++ b/training/conf/config.yaml
@@ -1,8 +1,9 @@
defaults:
- callbacks: wandb_ocr
- criterion: label_smoothing
- - dataset: iam_extended_paragraphs
+ - datamodule: iam_extended_paragraphs
- hydra: default
+ - logger: wandb
- lr_scheduler: one_cycle
- mapping: word_piece
- model: lit_transformer
@@ -15,3 +16,21 @@ tune: false
train: true
test: true
logging: INFO
+
+# path to original working directory
+# hydra hijacks working directory by changing it to the current log directory,
+# so it's useful to have this path as a special variable
+# learn more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory
+work_dir: ${hydra:runtime.cwd}
+
+# use `python run.py debug=true` for easy debugging!
+# this will run 1 train, val and test loop with only 1 batch
+# equivalent to running `python run.py trainer.fast_dev_run=true`
+# (this is placed here just for easier access from command line)
+debug: False
+
+# pretty print config at the start of the run using Rich library
+print_config: True
+
+# disable python warnings if they annoy you
+ignore_warnings: True
diff --git a/training/conf/criterion/label_smoothing.yaml b/training/conf/criterion/label_smoothing.yaml
index 13daba8..684b5bb 100644
--- a/training/conf/criterion/label_smoothing.yaml
+++ b/training/conf/criterion/label_smoothing.yaml
@@ -1,4 +1,3 @@
-_target_: text_recognizer.criterion.label_smoothing.LabelSmoothingLoss
-label_smoothing: 0.1
-vocab_size: 1006
+_target_: text_recognizer.criterions.label_smoothing.LabelSmoothingLoss
+smoothing: 0.1
ignore_index: 1002
diff --git a/training/conf/datamodule/iam_extended_paragraphs.yaml b/training/conf/datamodule/iam_extended_paragraphs.yaml
index 3070b56..2d1a03e 100644
--- a/training/conf/datamodule/iam_extended_paragraphs.yaml
+++ b/training/conf/datamodule/iam_extended_paragraphs.yaml
@@ -1,5 +1,6 @@
_target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs
-batch_size: 32
+batch_size: 4
num_workers: 12
train_fraction: 0.8
augment: true
+pin_memory: false
diff --git a/training/conf/lr_scheduler/one_cycle.yaml b/training/conf/lr_scheduler/one_cycle.yaml
index 5afdf81..eecee8a 100644
--- a/training/conf/lr_scheduler/one_cycle.yaml
+++ b/training/conf/lr_scheduler/one_cycle.yaml
@@ -1,8 +1,8 @@
_target_: torch.optim.lr_scheduler.OneCycleLR
max_lr: 1.0e-3
total_steps: null
-epochs: null
-steps_per_epoch: null
+epochs: 512
+steps_per_epoch: 4992
pct_start: 0.3
anneal_strategy: cos
cycle_momentum: true
diff --git a/training/conf/mapping/word_piece.yaml b/training/conf/mapping/word_piece.yaml
index 3792523..48384f5 100644
--- a/training/conf/mapping/word_piece.yaml
+++ b/training/conf/mapping/word_piece.yaml
@@ -1,4 +1,4 @@
-_target_: text_recognizer.data.mappings.WordPieceMapping
+_target_: text_recognizer.data.word_piece_mapping.WordPieceMapping
num_features: 1000
tokens: iamdb_1kwp_tokens_1000.txt
lexicon: iamdb_1kwp_lex_1000.txt
@@ -6,4 +6,4 @@ data_dir: null
use_words: false
prepend_wordsep: false
special_tokens: [ <s>, <e>, <p> ]
-extra_symbols: [ \n ]
+extra_symbols: [ "\n" ]
diff --git a/training/conf/model/lit_transformer.yaml b/training/conf/model/lit_transformer.yaml
index 6ffde4e..c190151 100644
--- a/training/conf/model/lit_transformer.yaml
+++ b/training/conf/model/lit_transformer.yaml
@@ -1,7 +1,7 @@
_target_: text_recognizer.models.transformer.TransformerLitModel
interval: step
monitor: val/loss
-ignore_tokens: [ <s>, <e>, <p> ]
+max_output_len: 451
start_token: <s>
end_token: <e>
pad_token: <p>
diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml
index a97157d..f76e892 100644
--- a/training/conf/network/conv_transformer.yaml
+++ b/training/conf/network/conv_transformer.yaml
@@ -6,6 +6,5 @@ _target_: text_recognizer.networks.conv_transformer.ConvTransformer
input_dims: [1, 576, 640]
hidden_dim: 96
dropout_rate: 0.2
-max_output_len: 451
num_classes: 1006
pad_index: 1002
diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml
index 90b9d8a..eb80f64 100644
--- a/training/conf/network/decoder/transformer_decoder.yaml
+++ b/training/conf/network/decoder/transformer_decoder.yaml
@@ -18,3 +18,4 @@ ff_kwargs:
dropout_rate: 0.2
cross_attend: true
pre_norm: true
+rotary_emb: null
diff --git a/training/run.py b/training/run.py
index 30479c6..13a6a82 100644
--- a/training/run.py
+++ b/training/run.py
@@ -12,35 +12,40 @@ from pytorch_lightning import (
Trainer,
)
from pytorch_lightning.loggers import LightningLoggerBase
-from text_recognizer.data.mappings import AbstractMapping
from torch import nn
+from text_recognizer.data.base_mapping import AbstractMapping
import utils
def run(config: DictConfig) -> Optional[float]:
"""Runs experiment."""
- utils.configure_logging(config.logging)
+ utils.configure_logging(config)
log.info("Starting experiment...")
if config.get("seed"):
- seed_everything(config.seed)
+ seed_everything(config.seed, workers=True)
log.info(f"Instantiating mapping <{config.mapping._target_}>")
mapping: AbstractMapping = hydra.utils.instantiate(config.mapping)
log.info(f"Instantiating datamodule <{config.datamodule._target_}>")
- datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule, mapping=mapping)
+ datamodule: LightningDataModule = hydra.utils.instantiate(
+ config.datamodule, mapping=mapping
+ )
log.info(f"Instantiating network <{config.network._target_}>")
network: nn.Module = hydra.utils.instantiate(config.network)
+ log.info(f"Instantiating criterion <{config.criterion._target_}>")
+ loss_fn: Type[nn.Module] = hydra.utils.instantiate(config.criterion)
+
log.info(f"Instantiating model <{config.model._target_}>")
model: LightningModule = hydra.utils.instantiate(
- **config.model,
+ config.model,
mapping=mapping,
network=network,
- criterion_config=config.criterion,
+ loss_fn=loss_fn,
optimizer_config=config.optimizer,
lr_scheduler_config=config.lr_scheduler,
_recursive_=False,
@@ -77,4 +82,4 @@ def run(config: DictConfig) -> Optional[float]:
trainer.test(model, datamodule=datamodule)
log.info(f"Best checkpoint path:\n{trainer.checkpoint_callback.best_model_path}")
- utils.finish(trainer)
+ utils.finish(logger)
diff --git a/training/utils.py b/training/utils.py
index ef74f61..d23396e 100644
--- a/training/utils.py
+++ b/training/utils.py
@@ -17,6 +17,10 @@ from tqdm import tqdm
import wandb
+def print_config(config: DictConfig) -> None:
+ print(OmegaConf.to_yaml(config))
+
+
@rank_zero_only
def configure_logging(config: DictConfig) -> None:
"""Configure the loguru logger for output to terminal and disk."""
@@ -30,7 +34,7 @@ def configure_callbacks(config: DictConfig,) -> List[Type[Callback]]:
callbacks = []
if config.get("callbacks"):
for callback_config in config.callbacks.values():
- if config.get("_target_"):
+ if callback_config.get("_target_"):
log.info(f"Instantiating callback <{callback_config._target_}>")
callbacks.append(hydra.utils.instantiate(callback_config))
return callbacks
@@ -41,8 +45,8 @@ def configure_logger(config: DictConfig) -> List[Type[LightningLoggerBase]]:
logger = []
if config.get("logger"):
for logger_config in config.logger.values():
- if config.get("_target_"):
- log.info(f"Instantiating callback <{logger_config._target_}>")
+ if logger_config.get("_target_"):
+ log.info(f"Instantiating logger <{logger_config._target_}>")
logger.append(hydra.utils.instantiate(logger_config))
return logger
@@ -67,17 +71,8 @@ def extras(config: DictConfig) -> None:
# Debuggers do not like GPUs and multiprocessing.
if config.trainer.get("gpus"):
config.trainer.gpus = 0
- if config.datamodule.get("pin_memory"):
- config.datamodule.pin_memory = False
- if config.datamodule.get("num_workers"):
- config.datamodule.num_workers = 0
-
- # Force multi-gpu friendly config.
- accelerator = config.trainer.get("accelerator")
- if accelerator in ["ddp", "ddp_spawn", "dp", "ddp2"]:
- log.info(
- f"Forcing ddp friendly configuration! <config.trainer.accelerator={accelerator}>"
- )
+ if config.trainer.get("precision"):
+ config.trainer.precision = 32
if config.datamodule.get("pin_memory"):
config.datamodule.pin_memory = False
if config.datamodule.get("num_workers"):