diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-11-12 23:42:03 +0100 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-11-12 23:42:03 +0100 |
commit | 8fdb6435e15703fa5b76df19728d905650ee1aef (patch) | |
tree | be3bec9e5cab4ef7f9d94528d102e57ce9b16c3f /src/text_recognizer/networks/loss | |
parent | dc28cbe2b4ed77be92ee8b2b69a20689c3bf02a4 (diff) | |
parent | 6cb08a110620ee09fe9d8a5d008197a801d025df (diff) |
Working cnn transformer.
Diffstat (limited to 'src/text_recognizer/networks/loss')
-rw-r--r-- | src/text_recognizer/networks/loss/__init__.py | 2 | ||||
-rw-r--r-- | src/text_recognizer/networks/loss/loss.py | 69 |
2 files changed, 71 insertions, 0 deletions
diff --git a/src/text_recognizer/networks/loss/__init__.py b/src/text_recognizer/networks/loss/__init__.py new file mode 100644 index 0000000..b489264 --- /dev/null +++ b/src/text_recognizer/networks/loss/__init__.py @@ -0,0 +1,2 @@ +"""Loss module.""" +from .loss import EmbeddingLoss, LabelSmoothingCrossEntropy diff --git a/src/text_recognizer/networks/loss/loss.py b/src/text_recognizer/networks/loss/loss.py new file mode 100644 index 0000000..cf9fa0d --- /dev/null +++ b/src/text_recognizer/networks/loss/loss.py @@ -0,0 +1,69 @@ +"""Implementations of custom loss functions.""" +from pytorch_metric_learning import distances, losses, miners, reducers +import torch +from torch import nn +from torch import Tensor +from torch.autograd import Variable +import torch.nn.functional as F + +__all__ = ["EmbeddingLoss", "LabelSmoothingCrossEntropy"] + + +class EmbeddingLoss: + """Metric loss for training encoders to produce information-rich latent embeddings.""" + + def __init__(self, margin: float = 0.2, type_of_triplets: str = "semihard") -> None: + self.distance = distances.CosineSimilarity() + self.reducer = reducers.ThresholdReducer(low=0) + self.loss_fn = losses.TripletMarginLoss( + margin=margin, distance=self.distance, reducer=self.reducer + ) + self.miner = miners.MultiSimilarityMiner(epsilon=margin, distance=self.distance) + + def __call__(self, embeddings: Tensor, labels: Tensor) -> Tensor: + """Computes the metric loss for the embeddings based on their labels. + + Args: + embeddings (Tensor): The laten vectors encoded by the network. + labels (Tensor): Labels of the embeddings. + + Returns: + Tensor: The metric loss for the embeddings. + + """ + hard_pairs = self.miner(embeddings, labels) + loss = self.loss_fn(embeddings, labels, hard_pairs) + return loss + + +class LabelSmoothingCrossEntropy(nn.Module): + """Label smoothing loss function.""" + + def __init__( + self, + classes: int, + smoothing: float = 0.0, + ignore_index: int = None, + dim: int = -1, + ) -> None: + super().__init__() + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + self.ignore_index = ignore_index + self.cls = classes + self.dim = dim + + def forward(self, pred: Tensor, target: Tensor) -> Tensor: + """Calculates the loss.""" + pred = pred.log_softmax(dim=self.dim) + with torch.no_grad(): + # true_dist = pred.data.clone() + true_dist = torch.zeros_like(pred) + true_dist.fill_(self.smoothing / (self.cls - 1)) + true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) + if self.ignore_index is not None: + true_dist[:, self.ignore_index] = 0 + mask = torch.nonzero(target == self.ignore_index, as_tuple=False) + if mask.dim() > 0: + true_dist.index_fill_(0, mask.squeeze(), 0.0) + return torch.mean(torch.sum(-true_dist * pred, dim=self.dim)) |