diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-10-04 22:08:38 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-10-04 22:08:38 +0200 |
commit | 8c4a0c2603975cfc63f4e4019386e001387c42c9 (patch) | |
tree | 213dac666a8ac71e7a48608ee492b80572a23584 /training | |
parent | 3ec10be9141bf71fb10d699b31a66b4e5046973c (diff) |
Add greedy decoder
Diffstat (limited to 'training')
-rw-r--r-- | training/conf/config.yaml | 1 | ||||
-rw-r--r-- | training/run.py | 10 |
2 files changed, 10 insertions, 1 deletions
diff --git a/training/conf/config.yaml b/training/conf/config.yaml index a95307f..e57a8a8 100644 --- a/training/conf/config.yaml +++ b/training/conf/config.yaml @@ -4,6 +4,7 @@ defaults: - _self_ - callbacks: default - criterion: cross_entropy + - decoder: greedy - datamodule: iam_extended_paragraphs - hydra: default - logger: wandb diff --git a/training/run.py b/training/run.py index 429b1a2..288a1ef 100644 --- a/training/run.py +++ b/training/run.py @@ -1,5 +1,5 @@ """Script to run experiments.""" -from typing import List, Optional, Type +from typing import Callable, List, Optional, Type import hydra from loguru import logger as log @@ -34,11 +34,19 @@ def run(config: DictConfig) -> Optional[float]: log.info(f"Instantiating criterion <{config.criterion._target_}>") loss_fn: Type[nn.Module] = hydra.utils.instantiate(config.criterion) + log.info(f"Instantiating decoder <{config.criterion._target_}>") + decoder: Type[Callable] = hydra.utils.instantiate( + config.decoder, + network=network, + tokenizer=datamodule.tokenizer, + ) + log.info(f"Instantiating model <{config.model._target_}>") model: LightningModule = hydra.utils.instantiate( config.model, network=network, tokenizer=datamodule.tokenizer, + decoder=decoder, loss_fn=loss_fn, optimizer_config=config.optimizer, lr_scheduler_config=config.lr_scheduler, |