From 9e0496ace20d5b7e3cde0dcfc1e8400039e51916 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 11 Oct 2021 22:13:45 +0200 Subject: Add Barlow Loss --- text_recognizer/criterions/barlow_twins.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 text_recognizer/criterions/barlow_twins.py (limited to 'text_recognizer') diff --git a/text_recognizer/criterions/barlow_twins.py b/text_recognizer/criterions/barlow_twins.py new file mode 100644 index 0000000..fe30b22 --- /dev/null +++ b/text_recognizer/criterions/barlow_twins.py @@ -0,0 +1,26 @@ +"""Barlow twins loss function.""" + +import torch +from torch import nn, Tensor + + +def off_diagonal(x: Tensor) -> Tensor: + n, m = x.shape + assert n == m + return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() + + +class BarlowTwinsLoss(nn.Module): + def __init__(self, dim: int, lambda_: float) -> None: + super().__init__() + self.bn = nn.BatchNorm1d(dim, affine=False) + self.lambda_ = lambda_ + + def forward(self, z1: Tensor, z2: Tensor) -> Tensor: + """Calculates the Barlow Twin loss.""" + c = self.bn(z1).T @ self.bn(z2) + c.div_(z1.shape[0]) + + on_diag = torch.diagonal(c).add_(-1).pow_(2).sum() + off_diag = off_diagonal(c).pow_(2).sum() + return on_diag + self.lambda_ * off_diag -- cgit v1.2.3-70-g09d2