summaryrefslogtreecommitdiff
path: root/src/training/util.py
blob: 132b2dc74ae1054fa62363d9586c2d2d1ca73983 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
"""Utility functions for training neural networks."""


class RunningAverage:
    """Maintains the arithmetic mean of a stream of scalar values.

    Typical use in a training loop::

        avg = RunningAverage()
        for batch in loader:
            avg.update(loss_value)
        print(avg())

    NOTE(review): values are accumulated with ``+=``. If callers pass
    framework tensors instead of Python floats, the sum stays a tensor.
    Prefer passing plain floats (e.g. ``loss.item()``) — confirm at call sites.
    """

    def __init__(self) -> None:
        """Create an empty running average (no values seen yet)."""
        self.steps: int = 0  # number of values passed to update()
        self.total: float = 0.0  # sum of all values passed to update()

    def reset(self) -> None:
        """Discard all accumulated values, returning to the empty state."""
        self.steps = 0
        self.total = 0.0

    def update(self, val: float) -> None:
        """Add one value to the running average.

        Args:
            val: The value to include in the mean.
        """
        self.total += val
        self.steps += 1

    def __call__(self) -> float:
        """Return the mean of all values seen so far.

        Returns:
            ``total / steps``, or ``0.0`` if ``update`` has never been called.
            The empty case is guarded rather than raising ZeroDivisionError,
            so logging an average over zero samples is safe.
        """
        if self.steps == 0:
            return 0.0
        return self.total / float(self.steps)