blob: 7cf1b455859940d6fd09e030ae13e46a6fa19a79 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
|
"""Utility functions for training neural networks."""
from typing import Dict, Optional
from loguru import logger
def log_val_metric(metrics_mean: Dict, epoch: Optional[int] = None) -> None:
    """Log the mean validation metrics at DEBUG level via the loguru ``logger``.

    Where the message ends up (file, terminal, ...) depends on the loguru sinks
    configured elsewhere.

    Args:
        metrics_mean: Mapping from metric name to its mean value. Each value is
            rendered with the ``:.4f`` format spec, so it must support float
            formatting.
        epoch: Optional epoch number to include in the message. ``0`` is a
            valid epoch and is included.
    """
    # Explicit None check: a plain truthiness test would drop the prefix for epoch 0.
    prefix = f"at epoch {epoch} - " if epoch is not None else " - "
    body = " - ".join(f"{name}: {value:.4f}" for name, value in metrics_mean.items())
    logger.debug("Validation metrics " + prefix + body)
class RunningAverage:
    """Maintains a running arithmetic mean of the values passed to ``update``.

    Example:
        >>> avg = RunningAverage()
        >>> avg.update(2.0)
        >>> avg.update(4.0)
        >>> avg()
        3.0
    """

    def __init__(self) -> None:
        """Start with no recorded values."""
        self.steps: int = 0  # number of values passed to update()
        self.total: float = 0.0  # sum of all values passed to update()

    def update(self, val: float) -> None:
        """Add ``val`` to the running sum and count it as one step."""
        self.total += val
        self.steps += 1

    def __call__(self) -> float:
        """Return the mean of all values recorded so far.

        Returns ``nan`` if ``update`` has never been called, instead of raising
        ``ZeroDivisionError``. ``nan`` is not a plausible metric value, so an
        empty average cannot be mistaken for a real one (unlike ``0.0``, which
        would look like a perfect loss).
        """
        if self.steps == 0:
            return float("nan")
        return self.total / self.steps
|