diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-03 18:18:48 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-03 18:18:48 +0200 |
commit | bd4bd443f339e95007bfdabf3e060db720f4d4b9 (patch) | |
tree | e55cb3744904f7c2a0348b100c7e92a65e538a16 /text_recognizer/criterions | |
parent | 75801019981492eedf9280cb352eea3d8e99b65f (diff) |
Training working, multiple bug fixes
Diffstat (limited to 'text_recognizer/criterions')
-rw-r--r-- | text_recognizer/criterions/label_smoothing.py | 38 |
1 files changed, 16 insertions, 22 deletions
diff --git a/text_recognizer/criterions/label_smoothing.py b/text_recognizer/criterions/label_smoothing.py index 40a7609..cc71c45 100644 --- a/text_recognizer/criterions/label_smoothing.py +++ b/text_recognizer/criterions/label_smoothing.py @@ -6,37 +6,31 @@ import torch.nn.functional as F class LabelSmoothingLoss(nn.Module): - """Label smoothing cross entropy loss.""" - - def __init__( - self, label_smoothing: float, vocab_size: int, ignore_index: int = -100 - ) -> None: - assert 0.0 < label_smoothing <= 1.0 - self.ignore_index = ignore_index + def __init__(self, ignore_index: int = -100, smoothing: float = 0.0, dim: int = -1): super().__init__() + assert 0.0 < smoothing <= 1.0 + self.ignore_index = ignore_index + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + self.dim = dim - smoothing_value = label_smoothing / (vocab_size - 2) - one_hot = torch.full((vocab_size,), smoothing_value) - one_hot[self.ignore_index] = 0 - self.register_buffer("one_hot", one_hot.unsqueeze(0)) - - self.confidence = 1.0 - label_smoothing - - def forward(self, output: Tensor, targets: Tensor) -> Tensor: + def forward(self, output: Tensor, target: Tensor) -> Tensor: """Computes the loss. Args: - output (Tensor): Predictions from the network. + output (Tensor): outputictions from the network. targets (Tensor): Ground truth. Shapes: - outpus: Batch size x num classes - targets: Batch size + TBC Returns: Tensor: Label smoothing loss. """ - model_prob = self.one_hot.repeat(targets.size(0), 1) - model_prob.scatter_(1, targets.unsqueeze(1), self.confidence) - model_prob.masked_fill_((targets == self.ignore_index).unsqueeze(1), 0) - return F.kl_div(output, model_prob, reduction="sum") + output = output.log_softmax(dim=self.dim) + with torch.no_grad(): + true_dist = torch.zeros_like(output) + true_dist.scatter_(1, target.unsqueeze(1), self.confidence) + true_dist.masked_fill_((target == 4).unsqueeze(1), 0) + true_dist += self.smoothing / output.size(self.dim) + return torch.mean(torch.sum(-true_dist * output, dim=self.dim)) |