summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--  training/callbacks/wandb_callbacks.py                    | 4
-rw-r--r--  training/conf/datamodule/emnist_lines.yaml               | 2
-rw-r--r--  training/conf/datamodule/iam_extended_paragraphs.yaml    | 4
-rw-r--r--  training/conf/datamodule/iam_lines.yaml                  | 4
-rw-r--r--  training/conf/experiment/conv_transformer_paragraphs.yaml | 3
-rw-r--r--  training/conf/mapping/characters.yaml                    | 2
-rw-r--r--  training/conf/tokenizer/default.yaml                     | 2
-rw-r--r--  training/run.py                                          | 3
8 files changed, 12 insertions, 12 deletions
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py
index dc59f19..7d4f48e 100644
--- a/training/callbacks/wandb_callbacks.py
+++ b/training/callbacks/wandb_callbacks.py
@@ -117,9 +117,9 @@ class LogTextPredictions(Callback):
imgs = imgs.to(device=pl_module.device)
logits = pl_module.predict(imgs)
- mapping = pl_module.mapping
+ tokenizer = pl_module.tokenizer
data = [
- wandb.Image(img, caption=mapping.get_text(pred))
+ wandb.Image(img, caption=tokenizer.decode(pred))
for img, pred, label in zip(
imgs[: self.num_samples],
logits[: self.num_samples],
diff --git a/training/conf/datamodule/emnist_lines.yaml b/training/conf/datamodule/emnist_lines.yaml
index 218df6c..ce35c3e 100644
--- a/training/conf/datamodule/emnist_lines.yaml
+++ b/training/conf/datamodule/emnist_lines.yaml
@@ -6,4 +6,4 @@ pin_memory: true
transform: transform/lines.yaml
test_transform: test_transform/lines.yaml
mapping:
- _target_: text_recognizer.data.mappings.EmnistMapping
+ _target_: text_recognizer.data.tokenizer.Tokenizer
diff --git a/training/conf/datamodule/iam_extended_paragraphs.yaml b/training/conf/datamodule/iam_extended_paragraphs.yaml
index c46714c..64c3964 100644
--- a/training/conf/datamodule/iam_extended_paragraphs.yaml
+++ b/training/conf/datamodule/iam_extended_paragraphs.yaml
@@ -13,6 +13,6 @@ target_transform:
_target_: text_recognizer.data.transforms.pad.Pad
max_len: 682
pad_index: 3
-mapping:
- _target_: text_recognizer.data.mappings.EmnistMapping
+tokenizer:
+ _target_: text_recognizer.data.tokenizer.Tokenizer
extra_symbols: ["\n"]
diff --git a/training/conf/datamodule/iam_lines.yaml b/training/conf/datamodule/iam_lines.yaml
index 4f1f1b8..f84116d 100644
--- a/training/conf/datamodule/iam_lines.yaml
+++ b/training/conf/datamodule/iam_lines.yaml
@@ -9,5 +9,5 @@ transform:
test_transform:
_target_: text_recognizer.data.stems.line.IamLinesStem
augment: false
-mapping:
- _target_: text_recognizer.data.mappings.EmnistMapping
+tokenizer:
+ _target_: text_recognizer.data.tokenizer.Tokenizer
diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml
index 60ff1bf..cdac387 100644
--- a/training/conf/experiment/conv_transformer_paragraphs.yaml
+++ b/training/conf/experiment/conv_transformer_paragraphs.yaml
@@ -84,8 +84,9 @@ network:
decoder:
_target_: text_recognizer.networks.transformer.Decoder
depth: 6
+ dim: *hidden_dim
block:
- _target_: text_recognizer.networks.transformer.DecoderBlock
+ _target_: text_recognizer.networks.transformer.decoder_block.DecoderBlock
self_attn:
_target_: text_recognizer.networks.transformer.Attention
dim: *hidden_dim
diff --git a/training/conf/mapping/characters.yaml b/training/conf/mapping/characters.yaml
deleted file mode 100644
index 8cbd55d..0000000
--- a/training/conf/mapping/characters.yaml
+++ /dev/null
@@ -1,2 +0,0 @@
-_target_: text_recognizer.data.mappings.EmnistMapping
-extra_symbols: [ "\n" ]
diff --git a/training/conf/tokenizer/default.yaml b/training/conf/tokenizer/default.yaml
new file mode 100644
index 0000000..2b1a8c9
--- /dev/null
+++ b/training/conf/tokenizer/default.yaml
@@ -0,0 +1,2 @@
+_target_: text_recognizer.data.tokenizer.Tokenizer
+extra_symbols: ["\n"]
diff --git a/training/run.py b/training/run.py
index 68cedc7..99059d6 100644
--- a/training/run.py
+++ b/training/run.py
@@ -14,7 +14,6 @@ from pytorch_lightning import (
from pytorch_lightning.loggers import LightningLoggerBase
from torch import nn
from torchinfo import summary
-
import utils
@@ -39,7 +38,7 @@ def run(config: DictConfig) -> Optional[float]:
model: LightningModule = hydra.utils.instantiate(
config.model,
network=network,
- mapping=datamodule.mapping,
+ tokenizer=datamodule.tokenizer,
loss_fn=loss_fn,
optimizer_config=config.optimizer,
lr_scheduler_config=config.lr_scheduler,