summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/loss
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
commit7e8e54e84c63171e748bbf09516fd517e6821ace (patch)
tree996093f75a5d488dddf7ea1f159ed343a561ef89 /src/text_recognizer/networks/loss
parentb0719d84138b6bbe5f04a4982dfca673aea1a368 (diff)
Inital commit for refactoring to lightning
Diffstat (limited to 'src/text_recognizer/networks/loss')
-rw-r--r--src/text_recognizer/networks/loss/__init__.py2
-rw-r--r--src/text_recognizer/networks/loss/loss.py69
2 files changed, 0 insertions, 71 deletions
diff --git a/src/text_recognizer/networks/loss/__init__.py b/src/text_recognizer/networks/loss/__init__.py
deleted file mode 100644
index b489264..0000000
--- a/src/text_recognizer/networks/loss/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-"""Loss module."""
-from .loss import EmbeddingLoss, LabelSmoothingCrossEntropy
diff --git a/src/text_recognizer/networks/loss/loss.py b/src/text_recognizer/networks/loss/loss.py
deleted file mode 100644
index cf9fa0d..0000000
--- a/src/text_recognizer/networks/loss/loss.py
+++ /dev/null
@@ -1,69 +0,0 @@
-"""Implementations of custom loss functions."""
-from pytorch_metric_learning import distances, losses, miners, reducers
-import torch
-from torch import nn
-from torch import Tensor
-from torch.autograd import Variable
-import torch.nn.functional as F
-
-__all__ = ["EmbeddingLoss", "LabelSmoothingCrossEntropy"]
-
-
-class EmbeddingLoss:
- """Metric loss for training encoders to produce information-rich latent embeddings."""
-
- def __init__(self, margin: float = 0.2, type_of_triplets: str = "semihard") -> None:
- self.distance = distances.CosineSimilarity()
- self.reducer = reducers.ThresholdReducer(low=0)
- self.loss_fn = losses.TripletMarginLoss(
- margin=margin, distance=self.distance, reducer=self.reducer
- )
- self.miner = miners.MultiSimilarityMiner(epsilon=margin, distance=self.distance)
-
- def __call__(self, embeddings: Tensor, labels: Tensor) -> Tensor:
- """Computes the metric loss for the embeddings based on their labels.
-
- Args:
- embeddings (Tensor): The laten vectors encoded by the network.
- labels (Tensor): Labels of the embeddings.
-
- Returns:
- Tensor: The metric loss for the embeddings.
-
- """
- hard_pairs = self.miner(embeddings, labels)
- loss = self.loss_fn(embeddings, labels, hard_pairs)
- return loss
-
-
-class LabelSmoothingCrossEntropy(nn.Module):
- """Label smoothing loss function."""
-
- def __init__(
- self,
- classes: int,
- smoothing: float = 0.0,
- ignore_index: int = None,
- dim: int = -1,
- ) -> None:
- super().__init__()
- self.confidence = 1.0 - smoothing
- self.smoothing = smoothing
- self.ignore_index = ignore_index
- self.cls = classes
- self.dim = dim
-
- def forward(self, pred: Tensor, target: Tensor) -> Tensor:
- """Calculates the loss."""
- pred = pred.log_softmax(dim=self.dim)
- with torch.no_grad():
- # true_dist = pred.data.clone()
- true_dist = torch.zeros_like(pred)
- true_dist.fill_(self.smoothing / (self.cls - 1))
- true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
- if self.ignore_index is not None:
- true_dist[:, self.ignore_index] = 0
- mask = torch.nonzero(target == self.ignore_index, as_tuple=False)
- if mask.dim() > 0:
- true_dist.index_fill_(0, mask.squeeze(), 0.0)
- return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))