From 31d58f2108165802d26eb1c1bdb9e5f052b4dd26 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Tue, 6 Apr 2021 22:31:54 +0200
Subject: Fix network args

---
 text_recognizer/networks/__init__.py          |  3 +++
 text_recognizer/networks/image_transformer.py |  5 ++--
 training/experiments/image_transformer.yaml   |  5 +---
 training/run_experiment.py                    | 36 +++++++++++++--------------
 4 files changed, 24 insertions(+), 25 deletions(-)

diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py
index e69de29..4dcaf2e 100644
--- a/text_recognizer/networks/__init__.py
+++ b/text_recognizer/networks/__init__.py
@@ -0,0 +1,3 @@
+"""Network modules"""
+from .image_transformer import ImageTransformer
+
diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py
index 85a84d2..edebca9 100644
--- a/text_recognizer/networks/image_transformer.py
+++ b/text_recognizer/networks/image_transformer.py
@@ -5,7 +5,7 @@ i.e. feature maps. A 2d positional encoding is applied to the feature maps for
 spatial information. The resulting feature are then set to a transformer decoder
 together with the target tokens.
 
-TODO: Local attention for transformer.j
+TODO: Local attention for lower layer in attention.
 
 """
 import importlib
@@ -39,7 +39,7 @@ class ImageTransformer(nn.Module):
         num_decoder_layers: int = 4,
         hidden_dim: int = 256,
         num_heads: int = 4,
-        expansion_dim: int = 4,
+        expansion_dim: int = 1024,
         dropout_rate: float = 0.1,
         transformer_activation: str = "glu",
     ) -> None:
@@ -109,6 +109,7 @@ class ImageTransformer(nn.Module):
 
     def _configure_mapping(self, mapping: str) -> Tuple[List[str], Dict[str, int]]:
         """Configures mapping."""
+        # TODO: Fix me!!!
        if mapping == "emnist":
             mapping, inverse_mapping, _ = emnist_mapping()
         return mapping, inverse_mapping
diff --git a/training/experiments/image_transformer.yaml b/training/experiments/image_transformer.yaml
index 012a19b..9e8f9fc 100644
--- a/training/experiments/image_transformer.yaml
+++ b/training/experiments/image_transformer.yaml
@@ -1,12 +1,9 @@
 network:
   type: ImageTransformer
   args:
-    input_shape: None
-    output_shape: None
     encoder:
       type: None
       args: None
-    mapping: sentence_piece
     num_decoder_layers: 4
     hidden_dim: 256
     num_heads: 4
@@ -60,7 +57,7 @@ callbacks:
 trainer:
   args:
     stochastic_weight_avg: true
-    auto_scale_batch_size: power
+    auto_scale_batch_size: binsearch
     gradient_clip_val: 0
     fast_dev_run: false
     gpus: 1
diff --git a/training/run_experiment.py b/training/run_experiment.py
index 0a67bfa..ea9f512 100644
--- a/training/run_experiment.py
+++ b/training/run_experiment.py
@@ -6,7 +6,6 @@ from typing import Dict, List, NamedTuple, Optional, Union, Type
 
 import click
 from loguru import logger
-import numpy as np
 from omegaconf import OmegaConf
 import pytorch_lightning as pl
 import torch
@@ -23,10 +22,10 @@ EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments"
 def _configure_logging(log_dir: Optional[Path], verbose: int = 0) -> None:
     """Configure the loguru logger for output to terminal and disk."""
 
-    def _get_level(verbose: int) -> int:
+    def _get_level(verbose: int) -> str:
         """Sets the logger level."""
-        levels = {0: 40, 1: 20, 2: 10}
-        verbose = verbose if verbose <= 2 else 2
+        levels = {0: "WARNING", 1: "INFO", 2: "DEBUG"}
+        verbose = min(verbose, 2)
         return levels[verbose]
 
     # Have to remove default logger to get tqdm to work properly.
@@ -50,7 +49,7 @@ def _import_class(module_and_class_name: str) -> type:
     return getattr(module, class_name)
 
 
-def _configure_pl_callbacks(
+def _configure_callbacks(
     args: List[Union[OmegaConf, NamedTuple]]
 ) -> List[Type[pl.callbacks.Callback]]:
     """Configures PyTorch Lightning callbacks."""
@@ -60,13 +59,16 @@ def _configure_pl_callbacks(
     return pl_callbacks
 
 
-def _configure_wandb_callback(
-    network: Type[nn.Module], args: Dict
+def _configure_logger(
+    network: Type[nn.Module], args: Dict, use_wandb: bool
 ) -> pl.loggers.WandbLogger:
-    """Configures wandb logger."""
-    pl_logger = pl.loggers.WandbLogger()
-    pl_logger.watch(network)
-    pl_logger.log_hyperparams(vars(args))
+    """Configures PyTorch Lightning logger."""
+    if use_wandb:
+        pl_logger = pl.loggers.WandbLogger()
+        pl_logger.watch(network)
+        pl_logger.log_hyperparams(vars(args))
+    else:
+        pl_logger = pl.logger.TensorBoardLogger("training/logs")
     return pl_logger
 
 
@@ -106,15 +108,11 @@ def run(path: str, train: bool, test: bool, tune: bool, use_wandb: bool) -> None
 
     # Initialize data object and network.
     data_module = data_module_class(**config.data.args)
-    network = network_class(**config.network.args)
+    network = network_class(**data_module.config(), **config.network.args)
 
     # Load callback and logger
-    callbacks = _configure_pl_callbacks(config.callbacks)
-    pl_logger = (
-        _configure_wandb_callback(network, config.network.args)
-        if use_wandb
-        else pl.logger.TensorBoardLogger("training/logs")
-    )
+    callbacks = _configure_callbacks(config.callbacks)
+    pl_logger = _configure_logger(network, config, use_wandb)
 
     # Checkpoint
     if config.load_checkpoint is not None:
@@ -125,7 +123,7 @@ def run(path: str, train: bool, test: bool, tune: bool, use_wandb: bool) -> None
             config.load_checkpoint, network=network, **config.model.args
         )
     else:
-        lit_model = lit_model_class(**config.model.args)
+        lit_model = lit_model_class(network=network, **config.model.args)
 
     trainer = pl.Trainer(
         **config.trainer,
-- 
cgit v1.2.3-70-g09d2