From 8c4a0c2603975cfc63f4e4019386e001387c42c9 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 4 Oct 2022 22:08:38 +0200 Subject: Add greedy decoder --- training/conf/config.yaml | 1 + training/run.py | 10 +++++++++- 2 files changed, 10 insertions(+), 1 deletion(-) (limited to 'training') diff --git a/training/conf/config.yaml b/training/conf/config.yaml index a95307f..e57a8a8 100644 --- a/training/conf/config.yaml +++ b/training/conf/config.yaml @@ -4,6 +4,7 @@ defaults: - _self_ - callbacks: default - criterion: cross_entropy + - decoder: greedy - datamodule: iam_extended_paragraphs - hydra: default - logger: wandb diff --git a/training/run.py b/training/run.py index 429b1a2..288a1ef 100644 --- a/training/run.py +++ b/training/run.py @@ -1,5 +1,5 @@ """Script to run experiments.""" -from typing import List, Optional, Type +from typing import Callable, List, Optional, Type import hydra from loguru import logger as log @@ -34,11 +34,19 @@ def run(config: DictConfig) -> Optional[float]: log.info(f"Instantiating criterion <{config.criterion._target_}>") loss_fn: Type[nn.Module] = hydra.utils.instantiate(config.criterion) + log.info(f"Instantiating decoder <{config.criterion._target_}>") + decoder: Type[Callable] = hydra.utils.instantiate( + config.decoder, + network=network, + tokenizer=datamodule.tokenizer, + ) + log.info(f"Instantiating model <{config.model._target_}>") model: LightningModule = hydra.utils.instantiate( config.model, network=network, tokenizer=datamodule.tokenizer, + decoder=decoder, loss_fn=loss_fn, optimizer_config=config.optimizer, lr_scheduler_config=config.lr_scheduler, -- cgit v1.2.3-70-g09d2