Diffstat (limited to 'training')
 training/conf/config.yaml |  1 +
 training/run.py           | 10 +++++++++-
 2 files changed, 10 insertions(+), 1 deletion(-)
diff --git a/training/conf/config.yaml b/training/conf/config.yaml
index a95307f..e57a8a8 100644
--- a/training/conf/config.yaml
+++ b/training/conf/config.yaml
@@ -4,6 +4,7 @@ defaults:
   - _self_
   - callbacks: default
   - criterion: cross_entropy
+  - decoder: greedy
   - datamodule: iam_extended_paragraphs
   - hydra: default
   - logger: wandb
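The new "decoder: greedy" default selects a Hydra config group, so at compose time Hydra looks for training/conf/decoder/greedy.yaml. That file is not part of this diff; the sketch below is an assumption about its contents, and the _target_ module path is hypothetical. Note it only needs static settings, because run.py passes network and tokenizer as keyword arguments when instantiating the decoder (see the run.py hunk below).

# training/conf/decoder/greedy.yaml -- hypothetical contents, not part of this diff.
# _target_ is the class Hydra instantiates; the module path is an assumption.
_target_: text_recognizer.decoders.greedy.GreedyDecoder
# network and tokenizer are deliberately absent: run.py supplies them as
# keyword arguments to hydra.utils.instantiate().

Composed this way, picking a different decoder becomes a command-line override, e.g. python run.py decoder=beam (assuming such a config file exists).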
diff --git a/training/run.py b/training/run.py
index 429b1a2..288a1ef 100644
--- a/training/run.py
+++ b/training/run.py
@@ -1,5 +1,5 @@
 """Script to run experiments."""
-from typing import List, Optional, Type
+from typing import Callable, List, Optional, Type
 
 import hydra
 from loguru import logger as log
@@ -34,11 +34,19 @@ def run(config: DictConfig) -> Optional[float]:
     log.info(f"Instantiating criterion <{config.criterion._target_}>")
     loss_fn: Type[nn.Module] = hydra.utils.instantiate(config.criterion)
 
+    log.info(f"Instantiating decoder <{config.decoder._target_}>")
+    decoder: Callable = hydra.utils.instantiate(
+        config.decoder,
+        network=network,
+        tokenizer=datamodule.tokenizer,
+    )
+
     log.info(f"Instantiating model <{config.model._target_}>")
     model: LightningModule = hydra.utils.instantiate(
         config.model,
         network=network,
         tokenizer=datamodule.tokenizer,
+        decoder=decoder,
         loss_fn=loss_fn,
         optimizer_config=config.optimizer,
         lr_scheduler_config=config.lr_scheduler,
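For hydra.utils.instantiate(config.decoder, network=..., tokenizer=...) to succeed, the target class must accept network and tokenizer in its constructor, and the Callable annotation implies the resulting object is invoked directly. A minimal sketch of a compatible greedy decoder follows; the class name, the network.encode/decode split, and the tokenizer's start/end/pad attributes are all assumptions, not code from this commit.

import torch
from torch import nn, Tensor


class GreedyDecoder:
    """Hypothetical greedy decoder compatible with the instantiate() call above."""

    def __init__(self, network: nn.Module, tokenizer, max_output_len: int = 256) -> None:
        self.network = network          # trained network, passed in by run.py
        self.tokenizer = tokenizer      # assumed to expose start/end/pad indices
        self.max_output_len = max_output_len  # assumed cap on decoded length

    @torch.no_grad()
    def __call__(self, x: Tensor) -> Tensor:
        """Greedily decode a batch of images into token indices."""
        bsz = x.shape[0]
        context = self.network.encode(x)  # assumed encoder/decoder split
        indices = torch.full(
            (bsz, self.max_output_len),
            self.tokenizer.pad_index,
            dtype=torch.long,
            device=x.device,
        )
        indices[:, 0] = self.tokenizer.start_index
        for i in range(1, self.max_output_len):
            logits = self.network.decode(indices[:, :i], context)  # (B, i, C)
            indices[:, i] = logits[:, -1].argmax(dim=-1)  # greedy step: argmax
            # Stop early once every sequence has emitted an end token.
            if (indices == self.tokenizer.end_index).any(dim=1).all():
                break
        return indices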