diff options
Diffstat (limited to 'src/training')
-rw-r--r-- | src/training/callbacks/wandb_callbacks.py | 4 | ||||
-rw-r--r-- | src/training/experiments/sample_experiment.yml | 96 | ||||
-rw-r--r-- | src/training/population_based_training/__init__.py | 1 | ||||
-rw-r--r-- | src/training/population_based_training/population_based_training.py | 1 | ||||
-rw-r--r-- | src/training/prepare_experiments.py | 13 | ||||
-rw-r--r-- | src/training/run_experiment.py | 22 | ||||
-rw-r--r-- | src/training/train.py | 4 |
7 files changed, 105 insertions, 36 deletions
diff --git a/src/training/callbacks/wandb_callbacks.py b/src/training/callbacks/wandb_callbacks.py index f64cbe1..6ada6df 100644 --- a/src/training/callbacks/wandb_callbacks.py +++ b/src/training/callbacks/wandb_callbacks.py @@ -72,7 +72,7 @@ class WandbImageLogger(Callback): def set_model(self, model: Type[Model]) -> None: """Sets the model and extracts validation images from the dataset.""" self.model = model - data_loader = self.model.data_loaders("val") + data_loader = self.model.data_loaders["val"] if self.example_indices is None: self.example_indices = np.random.randint( 0, len(data_loader.dataset.data), self.num_examples @@ -86,7 +86,7 @@ class WandbImageLogger(Callback): for i, image in enumerate(self.val_images): image = self.transforms(image) pred, conf = self.model.predict_on_image(image) - ground_truth = self.model._mapping[self.val_targets[i]] + ground_truth = self.model.mapper(int(self.val_targets[i])) caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}" images.append(wandb.Image(image, caption=caption)) diff --git a/src/training/experiments/sample_experiment.yml b/src/training/experiments/sample_experiment.yml index 70edb63..57198f1 100644 --- a/src/training/experiments/sample_experiment.yml +++ b/src/training/experiments/sample_experiment.yml @@ -1,28 +1,30 @@ experiment_group: Sample Experiments experiments: - - dataloader: EmnistDataLoaders - data_loader_args: - splits: [train, val] + - dataset: EmnistDataset + dataset_args: sample_to_balance: true subsample_fraction: null transform: null target_transform: null + seed: 4711 + data_loader_args: + splits: [train, val] batch_size: 256 shuffle: true num_workers: 8 cuda: true - seed: 4711 model: CharacterModel metrics: [accuracy] - # network: MLP - # network_args: - # input_size: 784 - # output_size: 62 - # num_layers: 3 - network: LeNet + network: MLP network_args: - input_size: [28, 28] + input_size: 784 output_size: 62 + num_layers: 3 + activation_fn: GELU + # network: LeNet + # network_args: + # output_size: 62 + # activation_fn: GELU train_args: batch_size: 256 epochs: 16 @@ -66,5 +68,75 @@ experiments: num_examples: 4 OneCycleLR: null - verbosity: 1 # 0, 1, 2 + verbosity: 2 # 0, 1, 2 resume_experiment: null + # - dataset: EmnistDataset + # dataset_args: + # sample_to_balance: true + # subsample_fraction: null + # transform: null + # target_transform: null + # seed: 4711 + # data_loader_args: + # splits: [train, val] + # batch_size: 256 + # shuffle: true + # num_workers: 8 + # cuda: true + # model: CharacterModel + # metrics: [accuracy] + # # network: MLP + # # network_args: + # # input_size: 784 + # # output_size: 62 + # # num_layers: 3 + # # activation_fn: GELU + # network: LeNet + # network_args: + # output_size: 62 + # activation_fn: GELU + # train_args: + # batch_size: 256 + # epochs: 16 + # criterion: CrossEntropyLoss + # criterion_args: + # weight: null + # ignore_index: -100 + # reduction: mean + # # optimizer: RMSprop + # # optimizer_args: + # # lr: 1.e-3 + # # alpha: 0.9 + # # eps: 1.e-7 + # # momentum: 0 + # # weight_decay: 0 + # # centered: false + # optimizer: AdamW + # optimizer_args: + # lr: 1.e-2 + # betas: [0.9, 0.999] + # eps: 1.e-08 + # weight_decay: 0 + # amsgrad: false + # # lr_scheduler: null + # lr_scheduler: OneCycleLR + # lr_scheduler_args: + # max_lr: 1.e-3 + # epochs: 16 + # callbacks: [Checkpoint, EarlyStopping, WandbCallback, WandbImageLogger, OneCycleLR] + # callback_args: + # Checkpoint: + # monitor: val_accuracy + # EarlyStopping: + # monitor: val_loss + # min_delta: 0.0 + # patience: 3 + # mode: min + # WandbCallback: + # log_batch_frequency: 10 + # WandbImageLogger: + # num_examples: 4 + # OneCycleLR: + # null + # verbosity: 2 # 0, 1, 2 + # resume_experiment: null diff --git a/src/training/population_based_training/__init__.py b/src/training/population_based_training/__init__.py new file mode 100644 index 0000000..868d739 --- /dev/null +++ b/src/training/population_based_training/__init__.py @@ -0,0 +1 @@ +"""TBC.""" diff --git a/src/training/population_based_training/population_based_training.py b/src/training/population_based_training/population_based_training.py new file mode 100644 index 0000000..868d739 --- /dev/null +++ b/src/training/population_based_training/population_based_training.py @@ -0,0 +1 @@ +"""TBC.""" diff --git a/src/training/prepare_experiments.py b/src/training/prepare_experiments.py index 5a665b3..97c0304 100644 --- a/src/training/prepare_experiments.py +++ b/src/training/prepare_experiments.py @@ -16,19 +16,8 @@ def run_experiments(experiments_filename: str) -> None: for index in range(num_experiments): experiment_config = experiments_config["experiments"][index] experiment_config["experiment_group"] = experiments_config["experiment_group"] - cmd = f"poetry run run-experiment --gpu=-1 --save --experiment_config={json.dumps(experiment_config)}" + cmd = f"python training/run_experiment.py --gpu=-1 --save --experiment_config='{json.dumps(experiment_config)}'" print(cmd) - run( - [ - "poetry", - "run", - "run-experiment", - "--gpu=-1", - "--save", - f"--experiment_config={json.dumps(experiment_config)}", - ], - check=True, - ) @click.command() diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py index c133ce5..d278dc2 100644 --- a/src/training/run_experiment.py +++ b/src/training/run_experiment.py @@ -58,12 +58,17 @@ def create_experiment_dir(model: Callable, experiment_config: Dict) -> Path: return experiment_dir +def check_args(args: Dict) -> Dict: + """Checks that the arguments are not None.""" + return args or {} + + def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]: """Loads all modules and arguments.""" - # Import the data loader module and arguments. - datasets_module = importlib.import_module("text_recognizer.datasets") - data_loader_ = getattr(datasets_module, experiment_config["dataloader"]) + # Import the data loader arguments. data_loader_args = experiment_config.get("data_loader_args", {}) + data_loader_args["dataset"] = experiment_config["dataset"] + data_loader_args["dataset_args"] = experiment_config.get("dataset_args", {}) # Import the model module and model arguments. models_module = importlib.import_module("text_recognizer.models") @@ -90,10 +95,12 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict] # Callbacks callback_modules = importlib.import_module("training.callbacks") - callbacks = [] - for callback in experiment_config["callbacks"]: - args = experiment_config["callback_args"][callback] or {} - callbacks.append(getattr(callback_modules, callback)(**args)) + callbacks = [ + getattr(callback_modules, callback)( + **check_args(experiment_config["callback_args"][callback]) + ) + for callback in experiment_config["callbacks"] + ] # Learning rate scheduler if experiment_config["lr_scheduler"] is not None: @@ -106,7 +113,6 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict] lr_scheduler_args = None model_args = { - "data_loader": data_loader_, "data_loader_args": data_loader_args, "metrics": metric_fns_, "network_fn": network_fn_, diff --git a/src/training/train.py b/src/training/train.py index 3334c2e..aaa0430 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -106,7 +106,7 @@ class Trainer: # Running average for the loss. loss_avg = RunningAverage() - data_loader = self.model.data_loaders("train") + data_loader = self.model.data_loaders["train"] with tqdm( total=len(data_loader), @@ -164,7 +164,7 @@ class Trainer: self.model.eval() # Running average for the loss. - data_loader = self.model.data_loaders("val") + data_loader = self.model.data_loaders["val"] # Running average for the loss. loss_avg = RunningAverage() |