diff options
Diffstat (limited to 'src/training/trainer/train.py')
-rw-r--r-- | src/training/trainer/train.py | 13 |
1 files changed, 12 insertions, 1 deletions
diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py
index 8ae994a..40a25da 100644
--- a/src/training/trainer/train.py
+++ b/src/training/trainer/train.py
@@ -38,6 +38,7 @@ class Trainer:
         callbacks: List[Type[Callback]],
         transformer_model: bool = False,
         max_norm: float = 0.0,
+        freeze_backbone: Optional[int] = None,
     ) -> None:
         """Initialization of the Trainer.
 
@@ -46,12 +47,15 @@ class Trainer:
             callbacks (CallbackList): List of callbacks to be called.
             transformer_model (bool): Transformer model flag, modifies the input to the model. Default is False.
             max_norm (float): Max norm for gradient clipping. Defaults to 0.0.
+            freeze_backbone (Optional[int]): How many epochs to freeze the backbone for. Used when training
+                Transformers. Default is None.
 
         """
         # Training arguments.
         self.start_epoch = 1
         self.max_epochs = max_epochs
         self.callbacks = callbacks
+        self.freeze_backbone = freeze_backbone
 
         # Flag for setting callbacks.
         self.callbacks_configured = False
@@ -115,7 +119,14 @@ class Trainer:
             # Forward pass.
             # Get the network prediction.
             if self.transformer_model:
-                output = self.model.network.forward(data, targets[:, :-1])
+                if self.freeze_backbone is not None and batch < self.freeze_backbone:
+                    with torch.no_grad():
+                        image_features = self.model.network.extract_image_features(data)
+                    output = self.model.network.decode_image_features(
+                        image_features, targets[:, :-1]
+                    )
+                else:
+                    output = self.model.network.forward(data, targets[:, :-1])
                 output = rearrange(output, "b t v -> (b t) v")
                 targets = rearrange(targets[:, 1:], "b t -> (b t)").long()
             else: