summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/loss.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/networks/loss.py')
-rw-r--r--src/text_recognizer/networks/loss.py34
1 files changed, 34 insertions, 0 deletions
diff --git a/src/text_recognizer/networks/loss.py b/src/text_recognizer/networks/loss.py
new file mode 100644
index 0000000..ff843cf
--- /dev/null
+++ b/src/text_recognizer/networks/loss.py
@@ -0,0 +1,34 @@
+"""Implementations of custom loss functions."""
+from pytorch_metric_learning import distances, losses, miners, reducers
+from torch import nn
+from torch import Tensor
+
+
+__all__ = ["EmbeddingLoss"]
+
+
+class EmbeddingLoss:
+ """Metric loss for training encoders to produce information-rich latent embeddings."""
+
+ def __init__(self, margin: float = 0.2, type_of_triplets: str = "semihard") -> None:
+ self.distance = distances.CosineSimilarity()
+ self.reducer = reducers.ThresholdReducer(low=0)
+ self.loss_fn = losses.TripletMarginLoss(
+ margin=margin, distance=self.distance, reducer=self.reducer
+ )
+ self.miner = miners.MultiSimilarityMiner(epsilon=margin, distance=self.distance)
+
+ def __call__(self, embeddings: Tensor, labels: Tensor) -> Tensor:
+ """Computes the metric loss for the embeddings based on their labels.
+
+ Args:
+ embeddings (Tensor): The laten vectors encoded by the network.
+ labels (Tensor): Labels of the embeddings.
+
+ Returns:
+ Tensor: The metric loss for the embeddings.
+
+ """
+ hard_pairs = self.miner(embeddings, labels)
+ loss = self.loss_fn(embeddings, labels, hard_pairs)
+ return loss