diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/callbacks/wandb_callbacks.py | 2 | ||||
-rw-r--r-- | training/run.py | 10 | ||||
-rw-r--r-- | training/utils.py | 2 |
3 files changed, 7 insertions, 7 deletions
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py index 451b0d5..6379cc0 100644 --- a/training/callbacks/wandb_callbacks.py +++ b/training/callbacks/wandb_callbacks.py @@ -94,7 +94,7 @@ class LogTextPredictions(Callback): super().__init__() def _log_predictions( - stage: str, trainer: Trainer, pl_module: LightningModule + self, stage: str, trainer: Trainer, pl_module: LightningModule ) -> None: """Logs the predicted text contained in the images.""" if not self.ready: diff --git a/training/run.py b/training/run.py index f745d61..d88a8f6 100644 --- a/training/run.py +++ b/training/run.py @@ -2,7 +2,7 @@ from typing import List, Optional, Type import hydra -from loguru import logger as log +import loguru.logger as log from omegaconf import DictConfig from pytorch_lightning import ( Callback, @@ -33,11 +33,11 @@ def run(config: DictConfig) -> Optional[float]: log.info(f"Instantiating model <{config.model._target_}>") model: LightningModule = hydra.utils.instantiate( - config.model, + **config.model, network=network, - criterion=config.criterion, - optimizer=config.optimizer, - lr_scheduler=config.lr_scheduler, + criterion_config=config.criterion, + optimizer_config=config.optimizer, + lr_scheduler_config=config.lr_scheduler, _recursive_=False, ) diff --git a/training/utils.py b/training/utils.py index ef74f61..564b9bb 100644 --- a/training/utils.py +++ b/training/utils.py @@ -3,7 +3,7 @@ from typing import Any, List, Type import warnings import hydra -from loguru import logger as log +import loguru.logger as log from omegaconf import DictConfig, OmegaConf from pytorch_lightning import ( Callback, |