diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/callbacks/wandb_callbacks.py | 4 | ||||
-rw-r--r-- | training/conf/datamodule/emnist_lines.yaml | 2 | ||||
-rw-r--r-- | training/conf/datamodule/iam_extended_paragraphs.yaml | 4 | ||||
-rw-r--r-- | training/conf/datamodule/iam_lines.yaml | 4 | ||||
-rw-r--r-- | training/conf/experiment/conv_transformer_paragraphs.yaml | 3 | ||||
-rw-r--r-- | training/conf/mapping/characters.yaml | 2 | ||||
-rw-r--r-- | training/conf/tokenizer/default.yaml | 2 | ||||
-rw-r--r-- | training/run.py | 3 |
8 files changed, 12 insertions, 12 deletions
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py index dc59f19..7d4f48e 100644 --- a/training/callbacks/wandb_callbacks.py +++ b/training/callbacks/wandb_callbacks.py @@ -117,9 +117,9 @@ class LogTextPredictions(Callback): imgs = imgs.to(device=pl_module.device) logits = pl_module.predict(imgs) - mapping = pl_module.mapping + tokenizer = pl_module.tokenizer data = [ - wandb.Image(img, caption=mapping.get_text(pred)) + wandb.Image(img, caption=tokenizer.decode(pred)) for img, pred, label in zip( imgs[: self.num_samples], logits[: self.num_samples], diff --git a/training/conf/datamodule/emnist_lines.yaml b/training/conf/datamodule/emnist_lines.yaml index 218df6c..ce35c3e 100644 --- a/training/conf/datamodule/emnist_lines.yaml +++ b/training/conf/datamodule/emnist_lines.yaml @@ -6,4 +6,4 @@ pin_memory: true transform: transform/lines.yaml test_transform: test_transform/lines.yaml mapping: - _target_: text_recognizer.data.mappings.EmnistMapping + _target_: text_recognizer.data.tokenizer.Tokenizer diff --git a/training/conf/datamodule/iam_extended_paragraphs.yaml b/training/conf/datamodule/iam_extended_paragraphs.yaml index c46714c..64c3964 100644 --- a/training/conf/datamodule/iam_extended_paragraphs.yaml +++ b/training/conf/datamodule/iam_extended_paragraphs.yaml @@ -13,6 +13,6 @@ target_transform: _target_: text_recognizer.data.transforms.pad.Pad max_len: 682 pad_index: 3 -mapping: - _target_: text_recognizer.data.mappings.EmnistMapping +tokenizer: + _target_: text_recognizer.data.tokenizer.Tokenizer extra_symbols: ["\n"] diff --git a/training/conf/datamodule/iam_lines.yaml b/training/conf/datamodule/iam_lines.yaml index 4f1f1b8..f84116d 100644 --- a/training/conf/datamodule/iam_lines.yaml +++ b/training/conf/datamodule/iam_lines.yaml @@ -9,5 +9,5 @@ transform: test_transform: _target_: text_recognizer.data.stems.line.IamLinesStem augment: false -mapping: - _target_: text_recognizer.data.mappings.EmnistMapping +tokenizer: + _target_: text_recognizer.data.tokenizer.Tokenizer diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml index 60ff1bf..cdac387 100644 --- a/training/conf/experiment/conv_transformer_paragraphs.yaml +++ b/training/conf/experiment/conv_transformer_paragraphs.yaml @@ -84,8 +84,9 @@ network: decoder: _target_: text_recognizer.networks.transformer.Decoder depth: 6 + dim: *hidden_dim block: - _target_: text_recognizer.networks.transformer.DecoderBlock + _target_: text_recognizer.networks.transformer.decoder_block.DecoderBlock self_attn: _target_: text_recognizer.networks.transformer.Attention dim: *hidden_dim diff --git a/training/conf/mapping/characters.yaml b/training/conf/mapping/characters.yaml deleted file mode 100644 index 8cbd55d..0000000 --- a/training/conf/mapping/characters.yaml +++ /dev/null @@ -1,2 +0,0 @@ -_target_: text_recognizer.data.mappings.EmnistMapping -extra_symbols: [ "\n" ] diff --git a/training/conf/tokenizer/default.yaml b/training/conf/tokenizer/default.yaml new file mode 100644 index 0000000..2b1a8c9 --- /dev/null +++ b/training/conf/tokenizer/default.yaml @@ -0,0 +1,2 @@ +_target_: text_recognizer.data.tokenizer.Tokenizer +extra_symbols: ["\n"] diff --git a/training/run.py b/training/run.py index 68cedc7..99059d6 100644 --- a/training/run.py +++ b/training/run.py @@ -14,7 +14,6 @@ from pytorch_lightning import ( from pytorch_lightning.loggers import LightningLoggerBase from torch import nn from torchinfo import summary - import utils @@ -39,7 +38,7 @@ def run(config: DictConfig) -> Optional[float]: model: LightningModule = hydra.utils.instantiate( config.model, network=network, - mapping=datamodule.mapping, + tokenizer=datamodule.tokenizer, loss_fn=loss_fn, optimizer_config=config.optimizer, lr_scheduler_config=config.lr_scheduler, |