summaryrefslogtreecommitdiff
path: root/src/training/trainer/train.py
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-12-07 22:54:04 +0100
committeraktersnurra <gustaf.rydholm@gmail.com>2020-12-07 22:54:04 +0100
commit25b5d6983d51e0e791b96a76beb7e49f392cd9a8 (patch)
tree526ba739714b3d040f7810c1a6be3ff0ba37fdb1 /src/training/trainer/train.py
parent5529e0fc9ca39e81fe0f08a54f257d32f0afe120 (diff)
Segmentation working!
Diffstat (limited to 'src/training/trainer/train.py')
-rw-r--r--src/training/trainer/train.py13
1 files changed, 12 insertions, 1 deletions
diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py
index 8ae994a..40a25da 100644
--- a/src/training/trainer/train.py
+++ b/src/training/trainer/train.py
@@ -38,6 +38,7 @@ class Trainer:
callbacks: List[Type[Callback]],
transformer_model: bool = False,
max_norm: float = 0.0,
+ freeze_backbone: Optional[int] = None,
) -> None:
"""Initialization of the Trainer.
@@ -46,12 +47,15 @@ class Trainer:
callbacks (CallbackList): List of callbacks to be called.
transformer_model (bool): Transformer model flag, modifies the input to the model. Default is False.
max_norm (float): Max norm for gradient clipping. Defaults to 0.0.
+ freeze_backbone (Optional[int]): How many epochs to freeze the backbone for. Used when training
+ Transformers. Default is None.
"""
# Training arguments.
self.start_epoch = 1
self.max_epochs = max_epochs
self.callbacks = callbacks
+ self.freeze_backbone = freeze_backbone
# Flag for setting callbacks.
self.callbacks_configured = False
@@ -115,7 +119,14 @@ class Trainer:
# Forward pass.
# Get the network prediction.
if self.transformer_model:
- output = self.model.network.forward(data, targets[:, :-1])
+ if self.freeze_backbone is not None and batch < self.freeze_backbone:
+ with torch.no_grad():
+ image_features = self.model.network.extract_image_features(data)
+ output = self.model.network.decode_image_features(
+ image_features, targets[:, :-1]
+ )
+ else:
+ output = self.model.network.forward(data, targets[:, :-1])
output = rearrange(output, "b t v -> (b t) v")
targets = rearrange(targets[:, 1:], "b t -> (b t)").long()
else: