author     aktersnurra <gustaf.rydholm@gmail.com>  2020-09-20 00:14:27 +0200
committer  aktersnurra <gustaf.rydholm@gmail.com>  2020-09-20 00:14:27 +0200
commit     e181195a699d7fa237f256d90ab4dedffc03d405 (patch)
tree       6d8d50731a7267c56f7bf3ed5ecec3990c0e55a5 /src/training
parent     3b06ef615a8db67a03927576e0c12fbfb2501f5f (diff)
Minor bug fixes etc.
Diffstat (limited to 'src/training')
-rw-r--r--  src/training/experiments/iam_line_ctc_experiment.yml |  94
-rw-r--r--  src/training/experiments/line_ctc_experiment.yml     |  97
-rw-r--r--  src/training/experiments/sample_experiment.yml       |  13
-rw-r--r--  src/training/run_experiment.py                       |  79
-rw-r--r--  src/training/trainer/callbacks/__init__.py           |  14
-rw-r--r--  src/training/trainer/callbacks/lr_schedulers.py      | 121
-rw-r--r--  src/training/trainer/callbacks/progress_bar.py       |   1
-rw-r--r--  src/training/trainer/callbacks/wandb_callbacks.py    |   1
-rw-r--r--  src/training/trainer/train.py                        |  22
9 files changed, 135 insertions, 307 deletions
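
In summary: the stand-alone iam_line_ctc_experiment.yml is deleted, and line_ctc_experiment.yml now trains LineCTCModel on IamLinesDataset with a ResidualNetwork backbone and a OneCycleLR schedule. The experiment files propagate max_epochs through YAML anchors, which lets run_experiment.py drop its check_args helper. The per-scheduler callbacks (StepLR, MultiStepLR, ReduceLROnPlateau, CyclicLR, OneCycleLR, CosineAnnealingLR) are collapsed into a single LRScheduler callback that steps per epoch or per batch depending on an interval key, the SWA callback reads its scheduler and start epoch from the model, and the trainer now appends these scheduler callbacks itself. run_experiment.py also gains a CUSTOM_LOSSES branch that resolves criteria from text_recognizer.networks.losses and evaluates embedding models with pytorch_metric_learning, and WandbCallback now logs the current learning rate.
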
diff --git a/src/training/experiments/iam_line_ctc_experiment.yml b/src/training/experiments/iam_line_ctc_experiment.yml
deleted file mode 100644
index 141c74e..0000000
--- a/src/training/experiments/iam_line_ctc_experiment.yml
+++ /dev/null
@@ -1,94 +0,0 @@
-experiment_group: Sample Experiments
-experiments:
- - train_args:
- batch_size: 24
- max_epochs: 128
- dataset:
- type: IamLinesDataset
- args:
- subsample_fraction: null
- transform: null
- target_transform: null
- train_args:
- num_workers: 6
- train_fraction: 0.85
- model: LineCTCModel
- metrics: [cer, wer]
- network:
- type: LineRecurrentNetwork
- args:
- # encoder: ResidualNetworkEncoder
- # encoder_args:
- # in_channels: 1
- # num_classes: 80
- # depths: [2, 2]
- # block_sizes: [128, 128]
- # activation: SELU
- # stn: false
- encoder: WideResidualNetwork
- encoder_args:
- in_channels: 1
- num_classes: 80
- depth: 16
- num_layers: 4
- width_factor: 2
- dropout_rate: 0.2
- activation: selu
- use_decoder: false
- flatten: true
- input_size: 256
- hidden_size: 128
- num_layers: 2
- num_classes: 80
- patch_size: [28, 14]
- stride: [1, 5]
- criterion:
- type: CTCLoss
- args:
- blank: 79
- optimizer:
- type: AdamW
- args:
- lr: 1.e-03
- betas: [0.9, 0.999]
- eps: 1.e-08
- weight_decay: false
- amsgrad: false
- # lr_scheduler:
- # type: OneCycleLR
- # args:
- # max_lr: 1.e-02
- # epochs: null
- # anneal_strategy: linear
- lr_scheduler:
- type: CosineAnnealingLR
- args:
- T_max: null
- swa_args:
- start: 75
- lr: 5.e-2
- callbacks: [Checkpoint, ProgressBar, WandbCallback, WandbImageLogger, SWA] # EarlyStopping, OneCycleLR]
- callback_args:
- Checkpoint:
- monitor: val_loss
- mode: min
- ProgressBar:
- epochs: null
- # log_batch_frequency: 100
- # EarlyStopping:
- # monitor: val_loss
- # min_delta: 0.0
- # patience: 7
- # mode: min
- WandbCallback:
- log_batch_frequency: 10
- WandbImageLogger:
- num_examples: 6
- # OneCycleLR:
- # null
- SWA:
- null
- verbosity: 1 # 0, 1, 2
- resume_experiment: null
- test: true
- test_metric: test_cer
diff --git a/src/training/experiments/line_ctc_experiment.yml b/src/training/experiments/line_ctc_experiment.yml
index c21c6a2..432d1cc 100644
--- a/src/training/experiments/line_ctc_experiment.yml
+++ b/src/training/experiments/line_ctc_experiment.yml
@@ -1,55 +1,46 @@
-experiment_group: Sample Experiments
+experiment_group: Lines Experiments
experiments:
- train_args:
- batch_size: 64
- max_epochs: 32
+ batch_size: 42
+ max_epochs: &max_epochs 32
dataset:
- type: EmnistLinesDataset
+ type: IamLinesDataset
args:
- subsample_fraction: 0.33
- max_length: 34
- min_overlap: 0
- max_overlap: 0.33
- num_samples: 10000
- seed: 4711
- blank: true
+ subsample_fraction: null
+ transform: null
+ target_transform: null
train_args:
- num_workers: 6
+ num_workers: 8
train_fraction: 0.85
model: LineCTCModel
metrics: [cer, wer]
network:
type: LineRecurrentNetwork
args:
- # encoder: ResidualNetworkEncoder
- # encoder_args:
- # in_channels: 1
- # num_classes: 81
- # depths: [2, 2]
- # block_sizes: [64, 128]
- # activation: SELU
- # stn: false
- encoder: WideResidualNetwork
- encoder_args:
+ backbone: ResidualNetwork
+ backbone_args:
in_channels: 1
- num_classes: 81
- depth: 16
- num_layers: 4
- width_factor: 2
- dropout_rate: 0.2
+ num_classes: 64 # Embedding
+ depths: [2,2]
+ block_sizes: [32,64]
activation: selu
- use_decoder: false
- flatten: true
- input_size: 256
- hidden_size: 128
+ stn: false
+ # encoder: ResidualNetwork
+ # encoder_args:
+ # pretrained: training/experiments/CharacterModel_EmnistDataset_ResidualNetwork/0917_203601/model/best.pt
+ # freeze: false
+ flatten: false
+ input_size: 64
+ hidden_size: 64
+ bidirectional: true
num_layers: 2
- num_classes: 81
- patch_size: [28, 14]
- stride: [1, 5]
+ num_classes: 80
+ patch_size: [28, 18]
+ stride: [1, 4]
criterion:
type: CTCLoss
args:
- blank: 80
+ blank: 79
optimizer:
type: AdamW
args:
@@ -58,40 +49,42 @@ experiments:
eps: 1.e-08
weight_decay: 5.e-4
amsgrad: false
- # lr_scheduler:
- # type: OneCycleLR
- # args:
- # max_lr: 1.e-03
- # epochs: null
- # anneal_strategy: linear
lr_scheduler:
- type: CosineAnnealingLR
+ type: OneCycleLR
args:
- T_max: null
+ max_lr: 1.e-02
+ epochs: *max_epochs
+ anneal_strategy: cos
+ pct_start: 0.475
+ cycle_momentum: true
+ base_momentum: 0.85
+ max_momentum: 0.9
+ div_factor: 10
+ final_div_factor: 10000
+ interval: step
+ # lr_scheduler:
+ # type: CosineAnnealingLR
+ # args:
+ # T_max: *max_epochs
swa_args:
- start: 4
+ start: 24
lr: 5.e-2
- callbacks: [Checkpoint, ProgressBar, WandbCallback, WandbImageLogger, SWA] # EarlyStopping, OneCycleLR]
+ callbacks: [Checkpoint, ProgressBar, WandbCallback, WandbImageLogger] # EarlyStopping]
callback_args:
Checkpoint:
monitor: val_loss
mode: min
ProgressBar:
- epochs: null
- log_batch_frequency: 100
+ epochs: *max_epochs
# EarlyStopping:
# monitor: val_loss
# min_delta: 0.0
- # patience: 5
+ # patience: 10
# mode: min
WandbCallback:
log_batch_frequency: 10
WandbImageLogger:
num_examples: 6
- # OneCycleLR:
- # null
- SWA:
- null
verbosity: 1 # 0, 1, 2
resume_experiment: null
test: true
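
The &max_epochs anchor and *max_epochs aliases above replace the removed check_args helper in run_experiment.py: PyYAML expands the alias when the file is parsed, so OneCycleLR and ProgressBar receive the epoch count directly. A minimal stand-alone sketch of that resolution (not project code):

    import yaml

    config = yaml.safe_load(
        "train_args:\n"
        "  max_epochs: &max_epochs 32\n"
        "lr_scheduler:\n"
        "  args:\n"
        "    epochs: *max_epochs\n"
    )
    # The alias is expanded during parsing, before any training code sees the config.
    assert config["lr_scheduler"]["args"]["epochs"] == 32
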
diff --git a/src/training/experiments/sample_experiment.yml b/src/training/experiments/sample_experiment.yml
index 17e220e..8664a15 100644
--- a/src/training/experiments/sample_experiment.yml
+++ b/src/training/experiments/sample_experiment.yml
@@ -2,7 +2,7 @@ experiment_group: Sample Experiments
experiments:
- train_args:
batch_size: 256
- max_epochs: 32
+ max_epochs: &max_epochs 32
dataset:
type: EmnistDataset
args:
@@ -66,16 +66,17 @@ experiments:
# type: OneCycleLR
# args:
# max_lr: 1.e-03
- # epochs: null
+ # epochs: *max_epochs
# anneal_strategy: linear
lr_scheduler:
type: CosineAnnealingLR
args:
- T_max: null
+ T_max: *max_epochs
+ interval: epoch
swa_args:
start: 2
lr: 5.e-2
- callbacks: [Checkpoint, ProgressBar, WandbCallback, WandbImageLogger, EarlyStopping, SWA] # OneCycleLR]
+ callbacks: [Checkpoint, ProgressBar, WandbCallback, WandbImageLogger, EarlyStopping]
callback_args:
Checkpoint:
monitor: val_accuracy
@@ -92,10 +93,6 @@ experiments:
WandbImageLogger:
num_examples: 4
use_transpose: true
- # OneCycleLR:
- # null
- SWA:
- null
verbosity: 0 # 0, 1, 2
resume_experiment: null
test: true
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py
index 286b0c6..a347d9f 100644
--- a/src/training/run_experiment.py
+++ b/src/training/run_experiment.py
@@ -10,6 +10,7 @@ from typing import Callable, Dict, List, Tuple, Type
import click
from loguru import logger
+import numpy as np
import torch
from tqdm import tqdm
from training.gpu_manager import GPUManager
@@ -20,11 +21,12 @@ import yaml
from text_recognizer.models import Model
+from text_recognizer.networks import losses
EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments"
-
+CUSTOM_LOSSES = ["EmbeddingLoss"]
DEFAULT_TRAIN_ARGS = {"batch_size": 64, "epochs": 16}
@@ -69,21 +71,6 @@ def create_experiment_dir(experiment_config: Dict) -> Path:
return experiment_dir, log_dir, model_dir
-def check_args(args: Dict, train_args: Dict) -> Dict:
- """Checks that the arguments are not None."""
- args = args or {}
-
- # I just want to set total epochs in train args, instead of changing all parameter.
- if "epochs" in args and args["epochs"] is None:
- args["epochs"] = train_args["max_epochs"]
-
- # For CosineAnnealingLR.
- if "T_max" in args and args["T_max"] is None:
- args["T_max"] = train_args["max_epochs"]
-
- return args or {}
-
-
def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]:
"""Loads all modules and arguments."""
# Import the data loader arguments.
@@ -115,8 +102,12 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]
network_args = experiment_config["network"].get("args", {})
# Criterion
- criterion_ = getattr(torch.nn, experiment_config["criterion"]["type"])
- criterion_args = experiment_config["criterion"].get("args", {})
+ if experiment_config["criterion"]["type"] in CUSTOM_LOSSES:
+ criterion_ = getattr(losses, experiment_config["criterion"]["type"])
+ criterion_args = experiment_config["criterion"].get("args", {})
+ else:
+ criterion_ = getattr(torch.nn, experiment_config["criterion"]["type"])
+ criterion_args = experiment_config["criterion"].get("args", {})
# Optimizers
optimizer_ = getattr(torch.optim, experiment_config["optimizer"]["type"])
@@ -129,13 +120,11 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]
lr_scheduler_ = getattr(
torch.optim.lr_scheduler, experiment_config["lr_scheduler"]["type"]
)
- lr_scheduler_args = check_args(
- experiment_config["lr_scheduler"].get("args", {}), train_args
- )
+ lr_scheduler_args = experiment_config["lr_scheduler"].get("args", {}) or {}
# SWA scheduler.
if "swa_args" in experiment_config:
- swa_args = check_args(experiment_config.get("swa_args", {}), train_args)
+ swa_args = experiment_config.get("swa_args", {}) or {}
else:
swa_args = None
@@ -159,19 +148,15 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]
def configure_callbacks(experiment_config: Dict, model_dir: Dict) -> CallbackList:
"""Configure a callback list for trainer."""
- train_args = experiment_config.get("train_args", {})
-
if "Checkpoint" in experiment_config["callback_args"]:
experiment_config["callback_args"]["Checkpoint"]["checkpoint_path"] = model_dir
- # Callbacks
+ # Initializes callbacks.
callback_modules = importlib.import_module("training.trainer.callbacks")
- callbacks = [
- getattr(callback_modules, callback)(
- **check_args(experiment_config["callback_args"][callback], train_args)
- )
- for callback in experiment_config["callbacks"]
- ]
+ callbacks = []
+ for callback in experiment_config["callbacks"]:
+ args = experiment_config["callback_args"][callback] or {}
+ callbacks.append(getattr(callback_modules, callback)(**args))
return callbacks
@@ -207,11 +192,35 @@ def load_from_checkpoint(model: Type[Model], log_dir: Path, model_dir: Path) ->
model.load_checkpoint(checkpoint_path)
+def evaluate_embedding(model: Type[Model]) -> Dict:
+ """Evaluates the embedding space."""
+ from pytorch_metric_learning import testers
+ from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
+
+ accuracy_calculator = AccuracyCalculator(
+ include=("mean_average_precision_at_r",), k=10
+ )
+
+ def get_all_embeddings(model: Type[Model]) -> Tuple:
+ tester = testers.BaseTester()
+ return tester.get_all_embeddings(model.test_dataset, model.network)
+
+ embeddings, labels = get_all_embeddings(model)
+ logger.info("Computing embedding accuracy")
+ accuracies = accuracy_calculator.get_accuracy(
+ embeddings, embeddings, np.squeeze(labels), np.squeeze(labels), True
+ )
+ logger.info(
+ f"Test set accuracy (MAP@10) = {accuracies['mean_average_precision_at_r']}"
+ )
+ return accuracies
+
+
def run_experiment(
experiment_config: Dict, save_weights: bool, device: str, use_wandb: bool = False
) -> None:
"""Runs an experiment."""
- logger.info(f"Experiment config: {json.dumps(experiment_config, indent=2)}")
+ logger.info(f"Experiment config: {json.dumps(experiment_config)}")
# Create new experiment.
experiment_dir, log_dir, model_dir = create_experiment_dir(experiment_config)
@@ -272,7 +281,11 @@ def run_experiment(
model.load_from_checkpoint(model_dir / "best.pt")
logger.info("Running inference on test set.")
- score = trainer.test(model)
+ if experiment_config["criterion"]["type"] in CUSTOM_LOSSES:
+ logger.info("Evaluating embedding.")
+ score = evaluate_embedding(model)
+ else:
+ score = trainer.test(model)
logger.info(f"Test set evaluation: {score}")
diff --git a/src/training/trainer/callbacks/__init__.py b/src/training/trainer/callbacks/__init__.py
index c81e4bf..e1bd858 100644
--- a/src/training/trainer/callbacks/__init__.py
+++ b/src/training/trainer/callbacks/__init__.py
@@ -3,12 +3,7 @@ from .base import Callback, CallbackList
from .checkpoint import Checkpoint
from .early_stopping import EarlyStopping
from .lr_schedulers import (
- CosineAnnealingLR,
- CyclicLR,
- MultiStepLR,
- OneCycleLR,
- ReduceLROnPlateau,
- StepLR,
+ LRScheduler,
SWA,
)
from .progress_bar import ProgressBar
@@ -18,15 +13,10 @@ __all__ = [
"Callback",
"CallbackList",
"Checkpoint",
- "CosineAnnealingLR",
"EarlyStopping",
+ "LRScheduler",
"WandbCallback",
"WandbImageLogger",
- "CyclicLR",
- "MultiStepLR",
- "OneCycleLR",
"ProgressBar",
- "ReduceLROnPlateau",
- "StepLR",
"SWA",
]
diff --git a/src/training/trainer/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py
index bb41d2d..907e292 100644
--- a/src/training/trainer/callbacks/lr_schedulers.py
+++ b/src/training/trainer/callbacks/lr_schedulers.py
@@ -7,113 +7,27 @@ from training.trainer.callbacks import Callback
from text_recognizer.models import Model
-class StepLR(Callback):
- """Callback for StepLR."""
+class LRScheduler(Callback):
+ """Generic learning rate scheduler callback."""
def __init__(self) -> None:
- """Initializes the callback."""
- super().__init__()
- self.lr_scheduler = None
-
- def set_model(self, model: Type[Model]) -> None:
- """Sets the model and lr scheduler."""
- self.model = model
- self.lr_scheduler = self.model.lr_scheduler
-
- def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
- """Takes a step at the end of every epoch."""
- self.lr_scheduler.step()
-
-
-class MultiStepLR(Callback):
- """Callback for MultiStepLR."""
-
- def __init__(self) -> None:
- """Initializes the callback."""
- super().__init__()
- self.lr_scheduler = None
-
- def set_model(self, model: Type[Model]) -> None:
- """Sets the model and lr scheduler."""
- self.model = model
- self.lr_scheduler = self.model.lr_scheduler
-
- def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
- """Takes a step at the end of every epoch."""
- self.lr_scheduler.step()
-
-
-class ReduceLROnPlateau(Callback):
- """Callback for ReduceLROnPlateau."""
-
- def __init__(self) -> None:
- """Initializes the callback."""
super().__init__()
- self.lr_scheduler = None
def set_model(self, model: Type[Model]) -> None:
"""Sets the model and lr scheduler."""
self.model = model
- self.lr_scheduler = self.model.lr_scheduler
+ self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"]
+ self.interval = self.model.lr_scheduler["interval"]
def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every epoch."""
- val_loss = logs["val_loss"]
- self.lr_scheduler.step(val_loss)
-
-
-class CyclicLR(Callback):
- """Callback for CyclicLR."""
-
- def __init__(self) -> None:
- """Initializes the callback."""
- super().__init__()
- self.lr_scheduler = None
-
- def set_model(self, model: Type[Model]) -> None:
- """Sets the model and lr scheduler."""
- self.model = model
- self.lr_scheduler = self.model.lr_scheduler
-
- def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
- """Takes a step at the end of every training batch."""
- self.lr_scheduler.step()
-
-
-class OneCycleLR(Callback):
- """Callback for OneCycleLR."""
-
- def __init__(self) -> None:
- """Initializes the callback."""
- super().__init__()
- self.lr_scheduler = None
-
- def set_model(self, model: Type[Model]) -> None:
- """Sets the model and lr scheduler."""
- self.model = model
- self.lr_scheduler = self.model.lr_scheduler
+ if self.interval == "epoch":
+ self.lr_scheduler.step()
def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every training batch."""
- self.lr_scheduler.step()
-
-
-class CosineAnnealingLR(Callback):
- """Callback for Cosine Annealing."""
-
- def __init__(self) -> None:
- """Initializes the callback."""
- super().__init__()
- self.lr_scheduler = None
-
- def set_model(self, model: Type[Model]) -> None:
- """Sets the model and lr scheduler."""
- self.model = model
- self.lr_scheduler = self.model.lr_scheduler
-
- def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
- """Takes a step at the end of every epoch."""
- self.lr_scheduler.step()
+ if self.interval == "step":
+ self.lr_scheduler.step()
class SWA(Callback):
@@ -122,21 +36,32 @@ class SWA(Callback):
def __init__(self) -> None:
"""Initializes the callback."""
super().__init__()
+ self.lr_scheduler = None
+ self.interval = None
self.swa_scheduler = None
+ self.swa_start = None
+ self.current_epoch = 1
def set_model(self, model: Type[Model]) -> None:
"""Sets the model and lr scheduler."""
self.model = model
- self.swa_start = self.model.swa_start
- self.swa_scheduler = self.model.lr_scheduler
- self.lr_scheduler = self.model.lr_scheduler
+ self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"]
+ self.interval = self.model.lr_scheduler["interval"]
+ self.swa_scheduler = self.model.swa_scheduler["swa_scheduler"]
+ self.swa_start = self.model.swa_scheduler["swa_start"]
def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every training batch."""
if epoch > self.swa_start:
self.model.swa_network.update_parameters(self.model.network)
self.swa_scheduler.step()
- else:
+ elif self.interval == "epoch":
+ self.lr_scheduler.step()
+ self.current_epoch = epoch
+
+ def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+ """Takes a step at the end of every training batch."""
+ if self.current_epoch < self.swa_start and self.interval == "step":
self.lr_scheduler.step()
def on_fit_end(self) -> None:
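
The generic LRScheduler callback and the reworked SWA callback above read dictionaries from the model (keys lr_scheduler/interval and swa_scheduler/swa_start). The model-side construction is not part of this diff, so the structure below is an assumption inferred from the keys the callbacks access:

    import torch

    network = torch.nn.Linear(8, 2)
    optimizer = torch.optim.AdamW(network.parameters(), lr=1e-3)

    # Assumed shape of model.lr_scheduler: stepped in on_train_batch_end when
    # interval == "step", otherwise in on_epoch_end.
    lr_scheduler = {
        "lr_scheduler": torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=1e-2, epochs=32, steps_per_epoch=100
        ),
        "interval": "step",
    }

    # Assumed shape of model.swa_scheduler, read by the SWA callback.
    swa_scheduler = {
        "swa_scheduler": torch.optim.swa_utils.SWALR(optimizer, swa_lr=5e-2),
        "swa_start": 24,
    }
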
diff --git a/src/training/trainer/callbacks/progress_bar.py b/src/training/trainer/callbacks/progress_bar.py
index 7829fa0..6c4305a 100644
--- a/src/training/trainer/callbacks/progress_bar.py
+++ b/src/training/trainer/callbacks/progress_bar.py
@@ -11,6 +11,7 @@ class ProgressBar(Callback):
def __init__(self, epochs: int, log_batch_frequency: int = None) -> None:
"""Initializes the tqdm callback."""
self.epochs = epochs
+ print(epochs, type(epochs))
self.log_batch_frequency = log_batch_frequency
self.progress_bar = None
self.val_metrics = {}
diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py
index 6643a44..d2df4d7 100644
--- a/src/training/trainer/callbacks/wandb_callbacks.py
+++ b/src/training/trainer/callbacks/wandb_callbacks.py
@@ -32,6 +32,7 @@ class WandbCallback(Callback):
def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Logs training metrics."""
if logs is not None:
+ logs["lr"] = self.model.optimizer.param_groups[0]["lr"]
self._on_batch_end(batch, logs)
def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
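
The added logs["lr"] entry reads the learning rate directly from the optimizer, so the value logged to W&B tracks whatever the scheduler last set. A stand-alone illustration (not project code):

    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.AdamW(params, lr=1e-3)
    # param_groups holds the live hyperparameters; schedulers mutate "lr" in place.
    current_lr = optimizer.param_groups[0]["lr"]  # 0.001 until a scheduler steps
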
diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py
index b240157..bd6a491 100644
--- a/src/training/trainer/train.py
+++ b/src/training/trainer/train.py
@@ -9,7 +9,7 @@ import numpy as np
import torch
from torch import Tensor
from torch.optim.swa_utils import update_bn
-from training.trainer.callbacks import Callback, CallbackList
+from training.trainer.callbacks import Callback, CallbackList, LRScheduler, SWA
from training.trainer.util import log_val_metric, RunningAverage
import wandb
@@ -47,8 +47,14 @@ class Trainer:
self.model = None
def _configure_callbacks(self) -> None:
+ """Instantiate the CallbackList."""
if not self.callbacks_configured:
- # Instantiate a CallbackList.
+ # If learning rate schedulers are present, they need to be added to the callbacks.
+ if self.model.swa_scheduler is not None:
+ self.callbacks.append(SWA())
+ elif self.model.lr_scheduler is not None:
+ self.callbacks.append(LRScheduler())
+
self.callbacks = CallbackList(self.model, self.callbacks)
def compute_metrics(
@@ -91,7 +97,7 @@ class Trainer:
# Forward pass.
# Get the network prediction.
- output = self.model.network(data)
+ output = self.model.forward(data)
# Compute the loss.
loss = self.model.loss_fn(output, targets)
@@ -130,7 +136,6 @@ class Trainer:
batch: int,
samples: Tuple[Tensor, Tensor],
loss_avg: Type[RunningAverage],
- use_swa: bool = False,
) -> Dict:
"""Performs the validation step."""
# Pass the tensor to the device for computation.
@@ -143,10 +148,7 @@ class Trainer:
# Forward pass.
# Get the network prediction.
# Use SWA if available and using test dataset.
- if use_swa and self.model.swa_network is None:
- output = self.model.swa_network(data)
- else:
- output = self.model.network(data)
+ output = self.model.forward(data)
# Compute the loss.
loss = self.model.loss_fn(output, targets)
@@ -238,7 +240,7 @@ class Trainer:
self.model.eval()
# Check if SWA network is available.
- use_swa = True if self.model.swa_network is not None else False
+ self.model.use_swa_model()
# Running average for the loss.
loss_avg = RunningAverage()
@@ -247,7 +249,7 @@ class Trainer:
summary = []
for batch, samples in enumerate(self.model.test_dataloader()):
- metrics = self.validation_step(batch, samples, loss_avg, use_swa)
+ metrics = self.validation_step(batch, samples, loss_avg)
summary.append(metrics)
# Compute mean of all test metrics.
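
The trainer builds on the standard torch.optim.swa_utils machinery (update_bn is already imported at the top of train.py). Below is a minimal sketch of the averaging loop the SWA callback drives, using the swa_args start epoch from the experiment config above; constructing AveragedModel and SWALR is assumed to happen on the model side:

    import torch
    from torch.optim.swa_utils import AveragedModel, SWALR

    network = torch.nn.Linear(8, 2)
    optimizer = torch.optim.SGD(network.parameters(), lr=1e-2)
    swa_network = AveragedModel(network)           # running average of the weights
    swa_scheduler = SWALR(optimizer, swa_lr=5e-2)  # anneals to and holds the SWA lr
    swa_start = 24

    for epoch in range(1, 33):
        # ... regular training and lr_scheduler steps happen here ...
        if epoch > swa_start:
            swa_network.update_parameters(network)  # fold current weights into the average
            swa_scheduler.step()
    # torch.optim.swa_utils.update_bn(train_dataloader, swa_network) would then
    # recompute batch-norm statistics before evaluating the averaged network.
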