From 4d7713746eb936832e84852e90292936b933e87d Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Thu, 22 Oct 2020 22:45:58 +0200 Subject: Transfomer added, many other changes. --- src/training/trainer/callbacks/checkpoint.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'src/training/trainer/callbacks/checkpoint.py') diff --git a/src/training/trainer/callbacks/checkpoint.py b/src/training/trainer/callbacks/checkpoint.py index 6fe06d3..a54e0a9 100644 --- a/src/training/trainer/callbacks/checkpoint.py +++ b/src/training/trainer/callbacks/checkpoint.py @@ -21,7 +21,7 @@ class Checkpoint(Callback): def __init__( self, - checkpoint_path: Path, + checkpoint_path: Union[str, Path], monitor: str = "accuracy", mode: str = "auto", min_delta: float = 0.0, @@ -29,14 +29,14 @@ class Checkpoint(Callback): """Monitors a quantity that will allow us to determine the best model weights. Args: - checkpoint_path (Path): Path to the experiment with the checkpoint. + checkpoint_path (Union[str, Path]): Path to the experiment with the checkpoint. monitor (str): Name of the quantity to monitor. Defaults to "accuracy". mode (str): Description of parameter `mode`. Defaults to "auto". min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. """ super().__init__() - self.checkpoint_path = checkpoint_path + self.checkpoint_path = Path(checkpoint_path) self.monitor = monitor self.mode = mode self.min_delta = torch.tensor(min_delta) -- cgit v1.2.3-70-g09d2