summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-11 22:13:45 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-11 22:13:45 +0200
commit9e0496ace20d5b7e3cde0dcfc1e8400039e51916 (patch)
treee194e74bc1871c8cfc4e13f26a566d330b57e9b1
parent7cdfb62c94d3cad808bfaf198272bcfb66734711 (diff)
Add Barlow Loss
-rw-r--r--text_recognizer/criterions/barlow_twins.py26
1 files changed, 26 insertions, 0 deletions
diff --git a/text_recognizer/criterions/barlow_twins.py b/text_recognizer/criterions/barlow_twins.py
new file mode 100644
index 0000000..fe30b22
--- /dev/null
+++ b/text_recognizer/criterions/barlow_twins.py
@@ -0,0 +1,26 @@
+"""Barlow twins loss function."""
+
+import torch
+from torch import nn, Tensor
+
+
+def off_diagonal(x: Tensor) -> Tensor:
+ n, m = x.shape
+ assert n == m
+ return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
+
+
+class BarlowTwinsLoss(nn.Module):
+ def __init__(self, dim: int, lambda_: float) -> None:
+ super().__init__()
+ self.bn = nn.BatchNorm1d(dim, affine=False)
+ self.lambda_ = lambda_
+
+ def forward(self, z1: Tensor, z2: Tensor) -> Tensor:
+ """Calculates the Barlow Twin loss."""
+ c = self.bn(z1).T @ self.bn(z2)
+ c.div_(z1.shape[0])
+
+ on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
+ off_diag = off_diagonal(c).pow_(2).sum()
+ return on_diag + self.lambda_ * off_diag