diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2023-01-29 21:16:40 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2023-01-29 21:16:40 +0100 |
commit | cf558e7146eabdf1e2c3435af31f4e87f4eb18bd (patch) | |
tree | 40332eb087b7d67f6b6fccef9f86f2db67854ef2 /text_recognizer/optimizers | |
parent | 5bbbec1f7021c2451099eb0d1e00a330275d1e16 (diff) |
Add laprop
Diffstat (limited to 'text_recognizer/optimizers')
-rw-r--r-- | text_recognizer/optimizers/__init__.py | 0 | ||||
-rw-r--r-- | text_recognizer/optimizers/laprop.py | 129 |
2 files changed, 129 insertions, 0 deletions
diff --git a/text_recognizer/optimizers/__init__.py b/text_recognizer/optimizers/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/text_recognizer/optimizers/__init__.py diff --git a/text_recognizer/optimizers/laprop.py b/text_recognizer/optimizers/laprop.py new file mode 100644 index 0000000..9b4f4d4 --- /dev/null +++ b/text_recognizer/optimizers/laprop.py @@ -0,0 +1,129 @@ +"""https://github.com/Z-T-WANG/LaProp-Optimizer/blob/master/laprop.py""" +from torch.optim import Optimizer +import math +import torch + + +class LaProp(Optimizer): + def __init__( + self, + params, + lr=4e-4, + betas=(0.9, 0.999), + eps=1e-15, + weight_decay=0, + amsgrad=False, + centered=False, + ): + + self.steps_before_using_centered = 10 + + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, + centered=centered, + ) + super(LaProp, self).__init__(params, defaults) + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) + amsgrad = group["amsgrad"] + centered = group["centered"] + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p.data) + # Exponential moving average of learning rates + state["exp_avg_lr_1"] = 0.0 + state["exp_avg_lr_2"] = 0.0 + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p.data) + if centered: + # Exponential moving average of gradient values as calculated by beta2 + state["exp_mean_avg_beta2"] = torch.zeros_like(p.data) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state["max_exp_avg_sq"] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + if centered: + exp_mean_avg_beta2 = state["exp_mean_avg_beta2"] + if amsgrad: + max_exp_avg_sq = state["max_exp_avg_sq"] + beta1, beta2 = group["betas"] + + state["step"] += 1 + + # Decay the first and second moment running average coefficient + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + + state["exp_avg_lr_1"] = ( + state["exp_avg_lr_1"] * beta1 + (1 - beta1) * group["lr"] + ) + state["exp_avg_lr_2"] = state["exp_avg_lr_2"] * beta2 + (1 - beta2) + + bias_correction1 = ( + state["exp_avg_lr_1"] / group["lr"] if group["lr"] != 0.0 else 1.0 + ) # 1 - beta1 ** state['step'] + step_size = 1 / bias_correction1 + + bias_correction2 = state["exp_avg_lr_2"] + + denom = exp_avg_sq + if centered: + exp_mean_avg_beta2.mul_(beta2).add_(1 - beta2, grad) + if state["step"] > self.steps_before_using_centered: + mean = exp_mean_avg_beta2**2 + denom = denom - mean + + if amsgrad: + if not ( + centered and state["step"] <= self.steps_before_using_centered + ): + # Maintains the maximum of all (centered) 2nd moment running avg. till now + torch.max(max_exp_avg_sq, denom, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. of gradient + denom = max_exp_avg_sq + + denom = denom.div(bias_correction2).sqrt_().add_(group["eps"]) + step_of_this_grad = grad / denom + exp_avg.mul_(beta1).add_((1 - beta1) * group["lr"], step_of_this_grad) + + p.data.add_(-step_size, exp_avg) + if group["weight_decay"] != 0: + p.data.add_(-group["weight_decay"], p.data) + + return loss |