From 4d1f2cef39688871d2caafce42a09316381a27ae Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 5 Jul 2021 23:05:25 +0200 Subject: Refactor with attr, working on cnn+transformer network --- text_recognizer/networks/loss/__init__.py | 2 -- .../networks/loss/label_smoothing_loss.py | 42 ---------------------- 2 files changed, 44 deletions(-) delete mode 100644 text_recognizer/networks/loss/__init__.py delete mode 100644 text_recognizer/networks/loss/label_smoothing_loss.py (limited to 'text_recognizer/networks/loss') diff --git a/text_recognizer/networks/loss/__init__.py b/text_recognizer/networks/loss/__init__.py deleted file mode 100644 index cb83608..0000000 --- a/text_recognizer/networks/loss/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""Loss module.""" -from .loss import LabelSmoothingCrossEntropy diff --git a/text_recognizer/networks/loss/label_smoothing_loss.py b/text_recognizer/networks/loss/label_smoothing_loss.py deleted file mode 100644 index 40a7609..0000000 --- a/text_recognizer/networks/loss/label_smoothing_loss.py +++ /dev/null @@ -1,42 +0,0 @@ -"""Implementations of custom loss functions.""" -import torch -from torch import nn -from torch import Tensor -import torch.nn.functional as F - - -class LabelSmoothingLoss(nn.Module): - """Label smoothing cross entropy loss.""" - - def __init__( - self, label_smoothing: float, vocab_size: int, ignore_index: int = -100 - ) -> None: - assert 0.0 < label_smoothing <= 1.0 - self.ignore_index = ignore_index - super().__init__() - - smoothing_value = label_smoothing / (vocab_size - 2) - one_hot = torch.full((vocab_size,), smoothing_value) - one_hot[self.ignore_index] = 0 - self.register_buffer("one_hot", one_hot.unsqueeze(0)) - - self.confidence = 1.0 - label_smoothing - - def forward(self, output: Tensor, targets: Tensor) -> Tensor: - """Computes the loss. - - Args: - output (Tensor): Predictions from the network. - targets (Tensor): Ground truth. - - Shapes: - outpus: Batch size x num classes - targets: Batch size - - Returns: - Tensor: Label smoothing loss. - """ - model_prob = self.one_hot.repeat(targets.size(0), 1) - model_prob.scatter_(1, targets.unsqueeze(1), self.confidence) - model_prob.masked_fill_((targets == self.ignore_index).unsqueeze(1), 0) - return F.kl_div(output, model_prob, reduction="sum") -- cgit v1.2.3-70-g09d2