diff options
Diffstat (limited to 'text_recognizer/networks/loss')
-rw-r--r-- | text_recognizer/networks/loss/__init__.py | 2 | ||||
-rw-r--r-- | text_recognizer/networks/loss/loss.py | 32 |
2 files changed, 2 insertions, 32 deletions
diff --git a/text_recognizer/networks/loss/__init__.py b/text_recognizer/networks/loss/__init__.py index b489264..cb83608 100644 --- a/text_recognizer/networks/loss/__init__.py +++ b/text_recognizer/networks/loss/__init__.py @@ -1,2 +1,2 @@ """Loss module.""" -from .loss import EmbeddingLoss, LabelSmoothingCrossEntropy +from .loss import LabelSmoothingCrossEntropy diff --git a/text_recognizer/networks/loss/loss.py b/text_recognizer/networks/loss/loss.py index cf9fa0d..d12dc9c 100644 --- a/text_recognizer/networks/loss/loss.py +++ b/text_recognizer/networks/loss/loss.py @@ -1,39 +1,9 @@ """Implementations of custom loss functions.""" -from pytorch_metric_learning import distances, losses, miners, reducers import torch from torch import nn from torch import Tensor -from torch.autograd import Variable -import torch.nn.functional as F -__all__ = ["EmbeddingLoss", "LabelSmoothingCrossEntropy"] - - -class EmbeddingLoss: - """Metric loss for training encoders to produce information-rich latent embeddings.""" - - def __init__(self, margin: float = 0.2, type_of_triplets: str = "semihard") -> None: - self.distance = distances.CosineSimilarity() - self.reducer = reducers.ThresholdReducer(low=0) - self.loss_fn = losses.TripletMarginLoss( - margin=margin, distance=self.distance, reducer=self.reducer - ) - self.miner = miners.MultiSimilarityMiner(epsilon=margin, distance=self.distance) - - def __call__(self, embeddings: Tensor, labels: Tensor) -> Tensor: - """Computes the metric loss for the embeddings based on their labels. - - Args: - embeddings (Tensor): The laten vectors encoded by the network. - labels (Tensor): Labels of the embeddings. - - Returns: - Tensor: The metric loss for the embeddings. - - """ - hard_pairs = self.miner(embeddings, labels) - loss = self.loss_fn(embeddings, labels, hard_pairs) - return loss +__all__ = ["LabelSmoothingCrossEntropy"] class LabelSmoothingCrossEntropy(nn.Module): |