diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-11-08 14:54:44 +0100 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-11-08 14:54:44 +0100 |
commit | dc28cbe2b4ed77be92ee8b2b69a20689c3bf02a4 (patch) | |
tree | 1b5fc0d06952e13727e85c4f973a26d277068453 /src/training/trainer/callbacks/checkpoint.py | |
parent | e181195a699d7fa237f256d90ab4dedffc03d405 (diff) |
new updates
Diffstat (limited to 'src/training/trainer/callbacks/checkpoint.py')
-rw-r--r-- | src/training/trainer/callbacks/checkpoint.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/src/training/trainer/callbacks/checkpoint.py b/src/training/trainer/callbacks/checkpoint.py index 6fe06d3..a54e0a9 100644 --- a/src/training/trainer/callbacks/checkpoint.py +++ b/src/training/trainer/callbacks/checkpoint.py @@ -21,7 +21,7 @@ class Checkpoint(Callback): def __init__( self, - checkpoint_path: Path, + checkpoint_path: Union[str, Path], monitor: str = "accuracy", mode: str = "auto", min_delta: float = 0.0, @@ -29,14 +29,14 @@ class Checkpoint(Callback): """Monitors a quantity that will allow us to determine the best model weights. Args: - checkpoint_path (Path): Path to the experiment with the checkpoint. + checkpoint_path (Union[str, Path]): Path to the experiment with the checkpoint. monitor (str): Name of the quantity to monitor. Defaults to "accuracy". mode (str): Description of parameter `mode`. Defaults to "auto". min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. """ super().__init__() - self.checkpoint_path = checkpoint_path + self.checkpoint_path = Path(checkpoint_path) self.monitor = monitor self.mode = mode self.min_delta = torch.tensor(min_delta) |