summaryrefslogtreecommitdiff
path: root/src/training/trainer/callbacks/checkpoint.py
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-11-08 14:54:44 +0100
committeraktersnurra <gustaf.rydholm@gmail.com>2020-11-08 14:54:44 +0100
commitdc28cbe2b4ed77be92ee8b2b69a20689c3bf02a4 (patch)
tree1b5fc0d06952e13727e85c4f973a26d277068453 /src/training/trainer/callbacks/checkpoint.py
parente181195a699d7fa237f256d90ab4dedffc03d405 (diff)
new updates
Diffstat (limited to 'src/training/trainer/callbacks/checkpoint.py')
-rw-r--r--src/training/trainer/callbacks/checkpoint.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/src/training/trainer/callbacks/checkpoint.py b/src/training/trainer/callbacks/checkpoint.py
index 6fe06d3..a54e0a9 100644
--- a/src/training/trainer/callbacks/checkpoint.py
+++ b/src/training/trainer/callbacks/checkpoint.py
@@ -21,7 +21,7 @@ class Checkpoint(Callback):
def __init__(
self,
- checkpoint_path: Path,
+ checkpoint_path: Union[str, Path],
monitor: str = "accuracy",
mode: str = "auto",
min_delta: float = 0.0,
@@ -29,14 +29,14 @@ class Checkpoint(Callback):
"""Monitors a quantity that will allow us to determine the best model weights.
Args:
- checkpoint_path (Path): Path to the experiment with the checkpoint.
+ checkpoint_path (Union[str, Path]): Path to the experiment with the checkpoint.
monitor (str): Name of the quantity to monitor. Defaults to "accuracy".
mode (str): Description of parameter `mode`. Defaults to "auto".
min_delta (float): Description of parameter `min_delta`. Defaults to 0.0.
"""
super().__init__()
- self.checkpoint_path = checkpoint_path
+ self.checkpoint_path = Path(checkpoint_path)
self.monitor = monitor
self.mode = mode
self.min_delta = torch.tensor(min_delta)