summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-02 21:13:48 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-02 21:13:48 +0200
commit75801019981492eedf9280cb352eea3d8e99b65f (patch)
tree6521cc4134459e42591b2375f70acd348741474e /training
parente5eca28438cd17d436359f2c6eee0bb9e55d2a8b (diff)
Fix log import, fix mapping in datamodules, fix nn modules can be hashed
Diffstat (limited to 'training')
-rw-r--r--training/conf/callbacks/wandb.yaml20
-rw-r--r--training/conf/callbacks/wandb/checkpoints.yaml4
-rw-r--r--training/conf/callbacks/wandb/code.yaml3
-rw-r--r--training/conf/callbacks/wandb/image_reconstructions.yaml0
-rw-r--r--training/conf/callbacks/wandb/ocr_predictions.yaml3
-rw-r--r--training/conf/callbacks/wandb/watch.yaml4
-rw-r--r--training/conf/callbacks/wandb_ocr.yaml6
-rw-r--r--training/conf/config.yaml18
-rw-r--r--training/conf/criterion/label_smoothing.yaml2
-rw-r--r--training/conf/hydra/default.yaml6
-rw-r--r--training/conf/mapping/word_piece.yaml (renamed from training/conf/model/mapping/word_piece.yaml)4
-rw-r--r--training/conf/model/lit_transformer.yaml5
-rw-r--r--training/conf/network/conv_transformer.yaml2
-rw-r--r--training/conf/network/decoder/transformer_decoder.yaml4
-rw-r--r--training/conf/trainer/default.yaml6
-rw-r--r--training/run.py11
-rw-r--r--training/utils.py2
17 files changed, 55 insertions, 45 deletions
diff --git a/training/conf/callbacks/wandb.yaml b/training/conf/callbacks/wandb.yaml
deleted file mode 100644
index 0017e11..0000000
--- a/training/conf/callbacks/wandb.yaml
+++ /dev/null
@@ -1,20 +0,0 @@
-defaults:
- - default.yaml
-
-watch_model:
- _target_: callbacks.wandb_callbacks.WatchModel
- log: all
- log_freq: 100
-
-upload_code_as_artifact:
- _target_: callbacks.wandb_callbacks.UploadCodeAsArtifact
- project_dir: ${work_dir}/text_recognizer
-
-upload_ckpts_as_artifact:
- _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact
- ckpt_dir: checkpoints/
- upload_best_only: true
-
-log_text_predictions:
- _target_: callbacks.wandb_callbacks.LogTextPredictions
- num_samples: 8
diff --git a/training/conf/callbacks/wandb/checkpoints.yaml b/training/conf/callbacks/wandb/checkpoints.yaml
new file mode 100644
index 0000000..a4a16ff
--- /dev/null
+++ b/training/conf/callbacks/wandb/checkpoints.yaml
@@ -0,0 +1,4 @@
+upload_ckpts_as_artifact:
+ _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact
+ ckpt_dir: checkpoints/
+ upload_best_only: true
diff --git a/training/conf/callbacks/wandb/code.yaml b/training/conf/callbacks/wandb/code.yaml
new file mode 100644
index 0000000..35f6ea3
--- /dev/null
+++ b/training/conf/callbacks/wandb/code.yaml
@@ -0,0 +1,3 @@
+upload_code_as_artifact:
+ _target_: callbacks.wandb_callbacks.UploadCodeAsArtifact
+ project_dir: ${work_dir}/text_recognizer
diff --git a/training/conf/callbacks/wandb/image_reconstructions.yaml b/training/conf/callbacks/wandb/image_reconstructions.yaml
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/training/conf/callbacks/wandb/image_reconstructions.yaml
diff --git a/training/conf/callbacks/wandb/ocr_predictions.yaml b/training/conf/callbacks/wandb/ocr_predictions.yaml
new file mode 100644
index 0000000..573fa96
--- /dev/null
+++ b/training/conf/callbacks/wandb/ocr_predictions.yaml
@@ -0,0 +1,3 @@
+log_text_predictions:
+ _target_: callbacks.wandb_callbacks.LogTextPredictions
+ num_samples: 8
diff --git a/training/conf/callbacks/wandb/watch.yaml b/training/conf/callbacks/wandb/watch.yaml
new file mode 100644
index 0000000..511608c
--- /dev/null
+++ b/training/conf/callbacks/wandb/watch.yaml
@@ -0,0 +1,4 @@
+watch_model:
+ _target_: callbacks.wandb_callbacks.WatchModel
+ log: all
+ log_freq: 100
diff --git a/training/conf/callbacks/wandb_ocr.yaml b/training/conf/callbacks/wandb_ocr.yaml
new file mode 100644
index 0000000..efa3dda
--- /dev/null
+++ b/training/conf/callbacks/wandb_ocr.yaml
@@ -0,0 +1,6 @@
+defaults:
+ - default
+ - wandb/watch
+ - wandb/code
+ - wandb/checkpoints
+ - wandb/ocr_predictions
diff --git a/training/conf/config.yaml b/training/conf/config.yaml
index a8e718e..93215ed 100644
--- a/training/conf/config.yaml
+++ b/training/conf/config.yaml
@@ -1,19 +1,17 @@
defaults:
- - network: vqvae
- - criterion: mse
- - optimizer: madgrad
- - lr_scheduler: one_cycle
- - model: lit_vqvae
+ - callbacks: wandb_ocr
+ - criterion: label_smoothing
- dataset: iam_extended_paragraphs
+ - hydra: default
+ - lr_scheduler: one_cycle
+ - mapping: word_piece
+ - model: lit_transformer
+ - network: conv_transformer
+ - optimizer: madgrad
- trainer: default
- - callbacks:
- - checkpoint
- - learning_rate_monitor
seed: 4711
-wandb: false
tune: false
train: true
test: true
-load_checkpoint: null
logging: INFO
diff --git a/training/conf/criterion/label_smoothing.yaml b/training/conf/criterion/label_smoothing.yaml
index ee47c59..13daba8 100644
--- a/training/conf/criterion/label_smoothing.yaml
+++ b/training/conf/criterion/label_smoothing.yaml
@@ -1,4 +1,4 @@
-_target_: text_recognizer.criterion.label_smoothing
+_target_: text_recognizer.criterion.label_smoothing.LabelSmoothingLoss
label_smoothing: 0.1
vocab_size: 1006
ignore_index: 1002
diff --git a/training/conf/hydra/default.yaml b/training/conf/hydra/default.yaml
new file mode 100644
index 0000000..dfd9721
--- /dev/null
+++ b/training/conf/hydra/default.yaml
@@ -0,0 +1,6 @@
+# output paths for hydra logs
+run:
+ dir: logs/runs/${now:%Y-%m-%d}/${now:%H-%M-%S}
+sweep:
+ dir: logs/multiruns/${now:%Y-%m-%d_%H-%M-%S}
+ subdir: ${hydra.job.num}
diff --git a/training/conf/model/mapping/word_piece.yaml b/training/conf/mapping/word_piece.yaml
index 39e2ba4..3792523 100644
--- a/training/conf/model/mapping/word_piece.yaml
+++ b/training/conf/mapping/word_piece.yaml
@@ -5,5 +5,5 @@ lexicon: iamdb_1kwp_lex_1000.txt
data_dir: null
use_words: false
prepend_wordsep: false
-special_tokens: ["<s>", "<e>", "<p>"]
-extra_symbols: ["\n"]
+special_tokens: [ <s>, <e>, <p> ]
+extra_symbols: [ \n ]
diff --git a/training/conf/model/lit_transformer.yaml b/training/conf/model/lit_transformer.yaml
index 5341d8e..6ffde4e 100644
--- a/training/conf/model/lit_transformer.yaml
+++ b/training/conf/model/lit_transformer.yaml
@@ -1,8 +1,5 @@
-defaults:
- - mapping: word_piece
-
_target_: text_recognizer.models.transformer.TransformerLitModel
-interval: null
+interval: step
monitor: val/loss
ignore_tokens: [ <s>, <e>, <p> ]
start_token: <s>
diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml
index 7d57a2d..a97157d 100644
--- a/training/conf/network/conv_transformer.yaml
+++ b/training/conf/network/conv_transformer.yaml
@@ -4,7 +4,7 @@ defaults:
_target_: text_recognizer.networks.conv_transformer.ConvTransformer
input_dims: [1, 576, 640]
-hidden_dim: 256
+hidden_dim: 96
dropout_rate: 0.2
max_output_len: 451
num_classes: 1006
diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml
index 3122de1..90b9d8a 100644
--- a/training/conf/network/decoder/transformer_decoder.yaml
+++ b/training/conf/network/decoder/transformer_decoder.yaml
@@ -2,12 +2,12 @@ defaults:
- rotary_emb: null
_target_: text_recognizer.networks.transformer.Decoder
-dim: 256
+dim: 96
depth: 2
num_heads: 8
attn_fn: text_recognizer.networks.transformer.attention.Attention
attn_kwargs:
- dim_head: 64
+ dim_head: 16
dropout_rate: 0.2
norm_fn: torch.nn.LayerNorm
ff_fn: text_recognizer.networks.transformer.mlp.FeedForward
diff --git a/training/conf/trainer/default.yaml b/training/conf/trainer/default.yaml
index 5ed6552..c665adc 100644
--- a/training/conf/trainer/default.yaml
+++ b/training/conf/trainer/default.yaml
@@ -6,6 +6,10 @@ gradient_clip_val: 0
fast_dev_run: false
gpus: 1
precision: 16
-max_epochs: 64
+max_epochs: 512
terminate_on_nan: true
weights_summary: top
+limit_train_batches: 1.0
+limit_val_batches: 1.0
+limit_test_batches: 1.0
+resume_from_checkpoint: null
diff --git a/training/run.py b/training/run.py
index d88a8f6..30479c6 100644
--- a/training/run.py
+++ b/training/run.py
@@ -2,7 +2,7 @@
from typing import List, Optional, Type
import hydra
-import loguru.logger as log
+from loguru import logger as log
from omegaconf import DictConfig
from pytorch_lightning import (
Callback,
@@ -12,6 +12,7 @@ from pytorch_lightning import (
Trainer,
)
from pytorch_lightning.loggers import LightningLoggerBase
+from text_recognizer.data.mappings import AbstractMapping
from torch import nn
import utils
@@ -25,15 +26,19 @@ def run(config: DictConfig) -> Optional[float]:
if config.get("seed"):
seed_everything(config.seed)
+ log.info(f"Instantiating mapping <{config.mapping._target_}>")
+ mapping: AbstractMapping = hydra.utils.instantiate(config.mapping)
+
log.info(f"Instantiating datamodule <{config.datamodule._target_}>")
- datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule)
+ datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule, mapping=mapping)
log.info(f"Instantiating network <{config.network._target_}>")
- network: nn.Module = hydra.utils.instantiate(config.network, **datamodule.config())
+ network: nn.Module = hydra.utils.instantiate(config.network)
log.info(f"Instantiating model <{config.model._target_}>")
model: LightningModule = hydra.utils.instantiate(
**config.model,
+ mapping=mapping,
network=network,
criterion_config=config.criterion,
optimizer_config=config.optimizer,
diff --git a/training/utils.py b/training/utils.py
index 564b9bb..ef74f61 100644
--- a/training/utils.py
+++ b/training/utils.py
@@ -3,7 +3,7 @@ from typing import Any, List, Type
import warnings
import hydra
-import loguru.logger as log
+from loguru import logger as log
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import (
Callback,