diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-05 20:47:55 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-05 20:47:55 +0200 |
commit | 9ae5fa1a88899180f88ddb14d4cef457ceb847e5 (patch) | |
tree | 4fe2bcd82553c8062eb0908ae6442c123addf55d /text_recognizer/networks/loss/loss.py | |
parent | 9e54591b7e342edc93b0bb04809a0f54045c6a15 (diff) |
Add new training loop with PyTorch Lightning, remove stale files
Diffstat (limited to 'text_recognizer/networks/loss/loss.py')
-rw-r--r-- | text_recognizer/networks/loss/loss.py | 32 |
1 files changed, 1 insertions, 31 deletions
diff --git a/text_recognizer/networks/loss/loss.py b/text_recognizer/networks/loss/loss.py index cf9fa0d..d12dc9c 100644 --- a/text_recognizer/networks/loss/loss.py +++ b/text_recognizer/networks/loss/loss.py @@ -1,39 +1,9 @@ """Implementations of custom loss functions.""" -from pytorch_metric_learning import distances, losses, miners, reducers import torch from torch import nn from torch import Tensor -from torch.autograd import Variable -import torch.nn.functional as F -__all__ = ["EmbeddingLoss", "LabelSmoothingCrossEntropy"] - - -class EmbeddingLoss: - """Metric loss for training encoders to produce information-rich latent embeddings.""" - - def __init__(self, margin: float = 0.2, type_of_triplets: str = "semihard") -> None: - self.distance = distances.CosineSimilarity() - self.reducer = reducers.ThresholdReducer(low=0) - self.loss_fn = losses.TripletMarginLoss( - margin=margin, distance=self.distance, reducer=self.reducer - ) - self.miner = miners.MultiSimilarityMiner(epsilon=margin, distance=self.distance) - - def __call__(self, embeddings: Tensor, labels: Tensor) -> Tensor: - """Computes the metric loss for the embeddings based on their labels. - - Args: - embeddings (Tensor): The laten vectors encoded by the network. - labels (Tensor): Labels of the embeddings. - - Returns: - Tensor: The metric loss for the embeddings. - - """ - hard_pairs = self.miner(embeddings, labels) - loss = self.loss_fn(embeddings, labels, hard_pairs) - return loss +__all__ = ["LabelSmoothingCrossEntropy"] class LabelSmoothingCrossEntropy(nn.Module): |