summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/losses.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/networks/losses.py')
-rw-r--r--src/text_recognizer/networks/losses.py31
1 files changed, 0 insertions, 31 deletions
diff --git a/src/text_recognizer/networks/losses.py b/src/text_recognizer/networks/losses.py
deleted file mode 100644
index 73e0641..0000000
--- a/src/text_recognizer/networks/losses.py
+++ /dev/null
@@ -1,31 +0,0 @@
-"""Implementations of custom loss functions."""
-from pytorch_metric_learning import distances, losses, miners, reducers
-from torch import nn
-from torch import Tensor
-
-
-class EmbeddingLoss:
- """Metric loss for training encoders to produce information-rich latent embeddings."""
-
- def __init__(self, margin: float = 0.2, type_of_triplets: str = "semihard") -> None:
- self.distance = distances.CosineSimilarity()
- self.reducer = reducers.ThresholdReducer(low=0)
- self.loss_fn = losses.TripletMarginLoss(
- margin=margin, distance=self.distance, reducer=self.reducer
- )
- self.miner = miners.MultiSimilarityMiner(epsilon=margin, distance=self.distance)
-
- def __call__(self, embeddings: Tensor, labels: Tensor) -> Tensor:
- """Computes the metric loss for the embeddings based on their labels.
-
- Args:
- embeddings (Tensor): The laten vectors encoded by the network.
- labels (Tensor): Labels of the embeddings.
-
- Returns:
- Tensor: The metric loss for the embeddings.
-
- """
- hard_pairs = self.miner(embeddings, labels)
- loss = self.loss_fn(embeddings, labels, hard_pairs)
- return loss