summaryrefslogtreecommitdiff
path: root/training/trainer/util.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
commit7e8e54e84c63171e748bbf09516fd517e6821ace (patch)
tree996093f75a5d488dddf7ea1f159ed343a561ef89 /training/trainer/util.py
parentb0719d84138b6bbe5f04a4982dfca673aea1a368 (diff)
Inital commit for refactoring to lightning
Diffstat (limited to 'training/trainer/util.py')
-rw-r--r--training/trainer/util.py28
1 files changed, 28 insertions, 0 deletions
diff --git a/training/trainer/util.py b/training/trainer/util.py
new file mode 100644
index 0000000..7cf1b45
--- /dev/null
+++ b/training/trainer/util.py
@@ -0,0 +1,28 @@
+"""Utility functions for training neural networks."""
+from typing import Dict, Optional
+
+from loguru import logger
+
+
+def log_val_metric(metrics_mean: Dict, epoch: Optional[int] = None) -> None:
+ """Logging of val metrics to file/terminal."""
+ log_str = "Validation metrics " + (f"at epoch {epoch} - " if epoch else " - ")
+ logger.debug(log_str + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items()))
+
+
+class RunningAverage:
+ """Maintains a running average."""
+
+ def __init__(self) -> None:
+ """Initializes the parameters."""
+ self.steps = 0
+ self.total = 0
+
+ def update(self, val: float) -> None:
+ """Updates the parameters."""
+ self.total += val
+ self.steps += 1
+
+ def __call__(self) -> float:
+ """Computes the running average."""
+ return self.total / float(self.steps)