From 9ae5fa1a88899180f88ddb14d4cef457ceb847e5 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 5 Apr 2021 20:47:55 +0200 Subject: Add new training loop with PyTorch Lightning, remove stale files --- text_recognizer/networks/loss/loss.py | 32 +------------------------------- 1 file changed, 1 insertion(+), 31 deletions(-) (limited to 'text_recognizer/networks/loss/loss.py') diff --git a/text_recognizer/networks/loss/loss.py b/text_recognizer/networks/loss/loss.py index cf9fa0d..d12dc9c 100644 --- a/text_recognizer/networks/loss/loss.py +++ b/text_recognizer/networks/loss/loss.py @@ -1,39 +1,9 @@ """Implementations of custom loss functions.""" -from pytorch_metric_learning import distances, losses, miners, reducers import torch from torch import nn from torch import Tensor -from torch.autograd import Variable -import torch.nn.functional as F -__all__ = ["EmbeddingLoss", "LabelSmoothingCrossEntropy"] - - -class EmbeddingLoss: - """Metric loss for training encoders to produce information-rich latent embeddings.""" - - def __init__(self, margin: float = 0.2, type_of_triplets: str = "semihard") -> None: - self.distance = distances.CosineSimilarity() - self.reducer = reducers.ThresholdReducer(low=0) - self.loss_fn = losses.TripletMarginLoss( - margin=margin, distance=self.distance, reducer=self.reducer - ) - self.miner = miners.MultiSimilarityMiner(epsilon=margin, distance=self.distance) - - def __call__(self, embeddings: Tensor, labels: Tensor) -> Tensor: - """Computes the metric loss for the embeddings based on their labels. - - Args: - embeddings (Tensor): The laten vectors encoded by the network. - labels (Tensor): Labels of the embeddings. - - Returns: - Tensor: The metric loss for the embeddings. - - """ - hard_pairs = self.miner(embeddings, labels) - loss = self.loss_fn(embeddings, labels, hard_pairs) - return loss +__all__ = ["LabelSmoothingCrossEntropy"] class LabelSmoothingCrossEntropy(nn.Module): -- cgit v1.2.3-70-g09d2