From 4d7713746eb936832e84852e90292936b933e87d Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Thu, 22 Oct 2020 22:45:58 +0200 Subject: Transfomer added, many other changes. --- src/text_recognizer/networks/loss.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 src/text_recognizer/networks/loss.py (limited to 'src/text_recognizer/networks/loss.py') diff --git a/src/text_recognizer/networks/loss.py b/src/text_recognizer/networks/loss.py new file mode 100644 index 0000000..ff843cf --- /dev/null +++ b/src/text_recognizer/networks/loss.py @@ -0,0 +1,34 @@ +"""Implementations of custom loss functions.""" +from pytorch_metric_learning import distances, losses, miners, reducers +from torch import nn +from torch import Tensor + + +__all__ = ["EmbeddingLoss"] + + +class EmbeddingLoss: + """Metric loss for training encoders to produce information-rich latent embeddings.""" + + def __init__(self, margin: float = 0.2, type_of_triplets: str = "semihard") -> None: + self.distance = distances.CosineSimilarity() + self.reducer = reducers.ThresholdReducer(low=0) + self.loss_fn = losses.TripletMarginLoss( + margin=margin, distance=self.distance, reducer=self.reducer + ) + self.miner = miners.MultiSimilarityMiner(epsilon=margin, distance=self.distance) + + def __call__(self, embeddings: Tensor, labels: Tensor) -> Tensor: + """Computes the metric loss for the embeddings based on their labels. + + Args: + embeddings (Tensor): The laten vectors encoded by the network. + labels (Tensor): Labels of the embeddings. + + Returns: + Tensor: The metric loss for the embeddings. + + """ + hard_pairs = self.miner(embeddings, labels) + loss = self.loss_fn(embeddings, labels, hard_pairs) + return loss -- cgit v1.2.3-70-g09d2