Diffstat (limited to 'training')
-rw-r--r--  training/callbacks/wandb_callbacks.py               |  8
-rw-r--r--  training/conf/callbacks/wandb_code.yaml             |  1
-rw-r--r--  training/conf/callbacks/wandb_htr.yaml              |  2
-rw-r--r--  training/conf/callbacks/wandb_vae.yaml              |  2
-rw-r--r--  training/conf/experiment/htr_char.yaml              |  7
-rw-r--r--  training/conf/experiment/vqvae.yaml                 |  8
-rw-r--r--  training/conf/model/lit_vqvae.yaml                  |  2
-rw-r--r--  training/conf/network/decoder/pixelcnn_encoder.yaml |  5
-rw-r--r--  training/conf/network/decoder/vae_decoder.yaml      |  5
-rw-r--r--  training/conf/network/encoder/pixelcnn_decoder.yaml |  5
-rw-r--r--  training/conf/network/encoder/vae_encoder.yaml      |  5
-rw-r--r--  training/conf/network/vqvae.yaml                    | 15
-rw-r--r--  training/conf/network/vqvae_pixelcnn.yaml           |  9
-rw-r--r--  training/conf/optimizer/madgrad.yaml                |  2
-rw-r--r--  training/conf/trainer/default.yaml                  |  2
-rw-r--r--  training/run.py                                     |  4
16 files changed, 61 insertions(+), 21 deletions(-)
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py
index 61d71df..2264750 100644
--- a/training/callbacks/wandb_callbacks.py
+++ b/training/callbacks/wandb_callbacks.py
@@ -39,8 +39,8 @@ class WatchModel(Callback):
class UploadCodeAsArtifact(Callback):
"""Upload all *.py files to W&B as an artifact, at the beginning of the run."""
- def __init__(self, project_dir: str) -> None:
- self.project_dir = Path(project_dir)
+ def __init__(self) -> None:
+ self.project_dir = Path(__file__).resolve().parents[2] / "text_recognizer"
@rank_zero_only
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
@@ -49,7 +49,7 @@ class UploadCodeAsArtifact(Callback):
experiment = logger.experiment
artifact = wandb.Artifact("project-source", type="code")
for filepath in self.project_dir.glob("**/*.py"):
- artifact.add_file(filepath)
+ artifact.add_file(str(filepath))
experiment.use_artifact(artifact)
@@ -60,7 +60,7 @@ class UploadCheckpointsAsArtifact(Callback):
def __init__(
self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False
) -> None:
- self.ckpt_dir = ckpt_dir
+ self.ckpt_dir = Path(__file__).resolve().parent / ckpt_dir
self.upload_best_only = upload_best_only
@rank_zero_only
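
The two callback changes above replace configurable paths with locations derived from the callback module's own position on disk, and make add_file W&B-compatible by passing a string instead of a Path. A minimal sketch of the resulting UploadCodeAsArtifact, assuming wandb_callbacks.py sits at training/callbacks/ (so parents[2] is the repository root) and that the logger attached to the trainer is a WandbLogger:

from pathlib import Path

import wandb
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.utilities import rank_zero_only


class UploadCodeAsArtifact(Callback):
    """Upload all *.py files under text_recognizer/ as a W&B code artifact."""

    def __init__(self) -> None:
        # training/callbacks/wandb_callbacks.py -> parents[2] is the repo root.
        self.project_dir = Path(__file__).resolve().parents[2] / "text_recognizer"

    @rank_zero_only
    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        experiment = trainer.logger.experiment  # assumes a WandbLogger is attached
        artifact = wandb.Artifact("project-source", type="code")
        for filepath in self.project_dir.glob("**/*.py"):
            artifact.add_file(str(filepath))  # add_file expects a str path
        experiment.use_artifact(artifact)
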
diff --git a/training/conf/callbacks/wandb_code.yaml b/training/conf/callbacks/wandb_code.yaml
index 35f6ea3..012cdce 100644
--- a/training/conf/callbacks/wandb_code.yaml
+++ b/training/conf/callbacks/wandb_code.yaml
@@ -1,3 +1,2 @@
upload_code_as_artifact:
_target_: callbacks.wandb_callbacks.UploadCodeAsArtifact
- project_dir: ${work_dir}/text_recognizer
diff --git a/training/conf/callbacks/wandb_htr.yaml b/training/conf/callbacks/wandb_htr.yaml
index 9c9a6da..44adb71 100644
--- a/training/conf/callbacks/wandb_htr.yaml
+++ b/training/conf/callbacks/wandb_htr.yaml
@@ -3,4 +3,4 @@ defaults:
- wandb_watch
- wandb_code
- wandb_checkpoints
- - wandb_ocr_predictions
+ - wandb_htr_predictions
diff --git a/training/conf/callbacks/wandb_vae.yaml b/training/conf/callbacks/wandb_vae.yaml
index 609a8e8..c7b09b0 100644
--- a/training/conf/callbacks/wandb_vae.yaml
+++ b/training/conf/callbacks/wandb_vae.yaml
@@ -1,6 +1,6 @@
defaults:
- default
- wandb_watch
- - wandb_code
- wandb_checkpoints
- wandb_image_reconstructions
+ # - wandb_code
diff --git a/training/conf/experiment/htr_char.yaml b/training/conf/experiment/htr_char.yaml
index 77126ae..e51a116 100644
--- a/training/conf/experiment/htr_char.yaml
+++ b/training/conf/experiment/htr_char.yaml
@@ -3,10 +3,15 @@
defaults:
- override /mapping: characters
+datamodule:
+ word_pieces: false
+
criterion:
ignore_index: 3
network:
- num_classes: 89
+ num_classes: 58
pad_index: 3
+
+model:
max_output_len: 682
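
An experiment file like this is merged over the composed base config, so its keys override whatever the config groups provide. A rough sketch of the merge semantics with placeholder base values (the real base config is not shown in this diff):

from omegaconf import OmegaConf

# Placeholder base values, for illustration only.
base = OmegaConf.create(
    {
        "datamodule": {"word_pieces": True},
        "network": {"num_classes": 89, "pad_index": 3},
        "model": {"max_output_len": 451},
    }
)
experiment = OmegaConf.create(
    {
        "datamodule": {"word_pieces": False},
        "criterion": {"ignore_index": 3},
        "network": {"num_classes": 58, "pad_index": 3},
        "model": {"max_output_len": 682},
    }
)
# Experiment keys win over the base values they shadow.
print(OmegaConf.to_yaml(OmegaConf.merge(base, experiment)))
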
diff --git a/training/conf/experiment/vqvae.yaml b/training/conf/experiment/vqvae.yaml
index 699612e..eb40f3b 100644
--- a/training/conf/experiment/vqvae.yaml
+++ b/training/conf/experiment/vqvae.yaml
@@ -8,14 +8,16 @@ defaults:
trainer:
max_epochs: 64
- gradient_clip_val: 0.25
+ # gradient_clip_val: 0.25
datamodule:
- batch_size: 32
+ batch_size: 16
lr_scheduler:
epochs: 64
- steps_per_epoch: 624
+ steps_per_epoch: 1245
optimizer:
lr: 1.0e-3
+
+summary: [1, 576, 640]
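
Two things change in lockstep here: the batch size is halved, so steps_per_epoch roughly doubles (624 -> 1245), and the new summary key declares the (channels, height, width) input shape that run.py feeds to torchsummary (see the last hunk below). A small sketch of the steps-per-epoch relation, with a placeholder sample count:

import math

def steps_per_epoch(num_train_samples: int, batch_size: int) -> int:
    # One optimizer step per batch, so steps scale inversely with batch size.
    return math.ceil(num_train_samples / batch_size)

# Placeholder count for illustration; the real value comes from the datamodule.
print(steps_per_epoch(19_920, 16))  # 1245
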
diff --git a/training/conf/model/lit_vqvae.yaml b/training/conf/model/lit_vqvae.yaml
index 8837573..409fa0d 100644
--- a/training/conf/model/lit_vqvae.yaml
+++ b/training/conf/model/lit_vqvae.yaml
@@ -1,4 +1,4 @@
_target_: text_recognizer.models.vqvae.VQVAELitModel
interval: step
monitor: val/loss
-latent_loss_weight: 0.25
+latent_loss_weight: 1.0
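
latent_loss_weight scales the VQ latent (commitment) term against the reconstruction term, so raising it from 0.25 to 1.0 weights the two equally. A minimal sketch of how such a weight is typically applied; the exact loss composition inside VQVAELitModel is not shown in this diff:

import torch
import torch.nn.functional as F

def vqvae_loss(
    reconstruction: torch.Tensor,
    target: torch.Tensor,
    latent_loss: torch.Tensor,
    latent_loss_weight: float = 1.0,
) -> torch.Tensor:
    # Reconstruction term plus a weighted VQ latent (commitment) term.
    return F.mse_loss(reconstruction, target) + latent_loss_weight * latent_loss
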
diff --git a/training/conf/network/decoder/pixelcnn_encoder.yaml b/training/conf/network/decoder/pixelcnn_encoder.yaml
new file mode 100644
index 0000000..47a130d
--- /dev/null
+++ b/training/conf/network/decoder/pixelcnn_encoder.yaml
@@ -0,0 +1,5 @@
+_target_: text_recognizer.networks.vqvae.pixelcnn.Encoder
+in_channels: 1
+hidden_dim: 8
+channels_multipliers: [1, 2, 8, 8]
+dropout_rate: 0.25
diff --git a/training/conf/network/decoder/vae_decoder.yaml b/training/conf/network/decoder/vae_decoder.yaml
new file mode 100644
index 0000000..b2090b3
--- /dev/null
+++ b/training/conf/network/decoder/vae_decoder.yaml
@@ -0,0 +1,5 @@
+_target_: text_recognizer.networks.vqvae.decoder.Decoder
+out_channels: 1
+hidden_dim: 32
+channels_multipliers: [8, 6, 2, 1]
+dropout_rate: 0.25
diff --git a/training/conf/network/encoder/pixelcnn_decoder.yaml b/training/conf/network/encoder/pixelcnn_decoder.yaml
new file mode 100644
index 0000000..3895164
--- /dev/null
+++ b/training/conf/network/encoder/pixelcnn_decoder.yaml
@@ -0,0 +1,5 @@
+_target_: text_recognizer.networks.vqvae.pixelcnn.Decoder
+out_channels: 1
+hidden_dim: 8
+channels_multipliers: [8, 8, 2, 1]
+dropout_rate: 0.25
diff --git a/training/conf/network/encoder/vae_encoder.yaml b/training/conf/network/encoder/vae_encoder.yaml
new file mode 100644
index 0000000..5dc6814
--- /dev/null
+++ b/training/conf/network/encoder/vae_encoder.yaml
@@ -0,0 +1,5 @@
+_target_: text_recognizer.networks.vqvae.encoder.Encoder
+in_channels: 1
+hidden_dim: 32
+channels_multipliers: [1, 2, 6, 8]
+dropout_rate: 0.25
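
Across the four new encoder/decoder configs, hidden_dim reads as a base channel width and channels_multipliers as per-stage scale factors; under that reading, the VAE encoder's last stage (32 * 8 = 256) matches hidden_dim: 256 in the reworked network/vqvae.yaml below. A hypothetical illustration of that interpretation:

# Hypothetical reading of hidden_dim / channels_multipliers, for illustration only.
hidden_dim = 32
channels_multipliers = [1, 2, 6, 8]  # vae_encoder.yaml
stage_widths = [hidden_dim * m for m in channels_multipliers]
print(stage_widths)  # [32, 64, 192, 256]
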
diff --git a/training/conf/network/vqvae.yaml b/training/conf/network/vqvae.yaml
index 5a5c066..835d0b7 100644
--- a/training/conf/network/vqvae.yaml
+++ b/training/conf/network/vqvae.yaml
@@ -1,8 +1,9 @@
-_target_: text_recognizer.networks.vqvae.VQVAE
-in_channels: 1
-res_channels: 32
-num_residual_layers: 2
-embedding_dim: 64
-num_embeddings: 512
+defaults:
+ - encoder: vae_encoder
+ - decoder: vae_decoder
+
+_target_: text_recognizer.networks.vqvae.vqvae.VQVAE
+hidden_dim: 256
+embedding_dim: 32
+num_embeddings: 1024
decay: 0.99
-activation: mish
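
The network config no longer carries every parameter itself; its defaults list pulls an encoder and a decoder from config groups, and hydra.utils.instantiate builds the nested objects from their _target_ strings. A rough sketch of what run.py ends up instantiating for the vqvae network, with the composed config written out by hand (the class constructors are assumed to accept exactly the listed keys):

import hydra
from omegaconf import OmegaConf

# Hand-written approximation of the composed `network: vqvae` config.
network_cfg = OmegaConf.create(
    {
        "_target_": "text_recognizer.networks.vqvae.vqvae.VQVAE",
        "hidden_dim": 256,
        "embedding_dim": 32,
        "num_embeddings": 1024,
        "decay": 0.99,
        "encoder": {
            "_target_": "text_recognizer.networks.vqvae.encoder.Encoder",
            "in_channels": 1,
            "hidden_dim": 32,
            "channels_multipliers": [1, 2, 6, 8],
            "dropout_rate": 0.25,
        },
        "decoder": {
            "_target_": "text_recognizer.networks.vqvae.decoder.Decoder",
            "out_channels": 1,
            "hidden_dim": 32,
            "channels_multipliers": [8, 6, 2, 1],
            "dropout_rate": 0.25,
        },
    }
)

# With recursive instantiation (the Hydra 1.1+ default), the encoder and decoder
# nodes are built first and passed to VQVAE.__init__ as keyword arguments.
network = hydra.utils.instantiate(network_cfg)
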
diff --git a/training/conf/network/vqvae_pixelcnn.yaml b/training/conf/network/vqvae_pixelcnn.yaml
new file mode 100644
index 0000000..10200bc
--- /dev/null
+++ b/training/conf/network/vqvae_pixelcnn.yaml
@@ -0,0 +1,9 @@
+defaults:
+ - encoder: pixelcnn_encoder
+ - decoder: pixelcnn_decoder
+
+_target_: text_recognizer.networks.vqvae.vqvae.VQVAE
+hidden_dim: 64
+embedding_dim: 32
+num_embeddings: 512
+decay: 0.99
diff --git a/training/conf/optimizer/madgrad.yaml b/training/conf/optimizer/madgrad.yaml
index 84626d3..46b2fff 100644
--- a/training/conf/optimizer/madgrad.yaml
+++ b/training/conf/optimizer/madgrad.yaml
@@ -1,5 +1,5 @@
_target_: madgrad.MADGRAD
-lr: 1.0e-3
+lr: 2.0e-4
momentum: 0.9
weight_decay: 0
eps: 1.0e-6
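
The default MADGRAD learning rate drops to 2e-4 (the vqvae experiment above still sets lr: 1.0e-3 on top of it). A minimal sketch of how this config maps onto the madgrad package, assuming run.py passes the network parameters when instantiating the optimizer:

from madgrad import MADGRAD
from torch import nn

network = nn.Linear(8, 8)  # stand-in for the instantiated network
optimizer = MADGRAD(
    network.parameters(),
    lr=2.0e-4,
    momentum=0.9,
    weight_decay=0,
    eps=1.0e-6,
)
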
diff --git a/training/conf/trainer/default.yaml b/training/conf/trainer/default.yaml
index 0fa9ce1..c665adc 100644
--- a/training/conf/trainer/default.yaml
+++ b/training/conf/trainer/default.yaml
@@ -8,7 +8,7 @@ gpus: 1
precision: 16
max_epochs: 512
terminate_on_nan: true
-weights_summary: full
+weights_summary: top
limit_train_batches: 1.0
limit_val_batches: 1.0
limit_test_batches: 1.0
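
weights_summary: top limits PyTorch Lightning's built-in module summary to the top-level children instead of every submodule, which matters less now that run.py can print a full torchsummary report on demand. A sketch of the equivalent Trainer call, on the Lightning versions that still accept these flags:

from pytorch_lightning import Trainer

trainer = Trainer(
    gpus=1,
    precision=16,
    max_epochs=512,
    terminate_on_nan=True,
    weights_summary="top",  # was "full": summarize only top-level modules
)
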
diff --git a/training/run.py b/training/run.py
index 13a6a82..a2529b0 100644
--- a/training/run.py
+++ b/training/run.py
@@ -13,6 +13,7 @@ from pytorch_lightning import (
)
from pytorch_lightning.loggers import LightningLoggerBase
from torch import nn
+from torchsummary import summary
from text_recognizer.data.base_mapping import AbstractMapping
import utils
@@ -37,6 +38,9 @@ def run(config: DictConfig) -> Optional[float]:
log.info(f"Instantiating network <{config.network._target_}>")
network: nn.Module = hydra.utils.instantiate(config.network)
+ if config.summary:
+ summary(network, tuple(config.summary), device="cpu")
+
log.info(f"Instantiating criterion <{config.criterion._target_}>")
loss_fn: Type[nn.Module] = hydra.utils.instantiate(config.criterion)
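
With the new summary key, run.py prints a layer-by-layer torchsummary report of the freshly instantiated network before training starts; the vqvae experiment sets it to [1, 576, 640], i.e. a single-channel 576x640 image. A usage sketch with a stand-in model:

from torch import nn
from torchsummary import summary

network = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU())  # stand-in model
summary(network, (1, 576, 640), device="cpu")  # config.summary == [1, 576, 640]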