summaryrefslogtreecommitdiff
path: root/text_recognizer/criterions/barlow_twins.py
blob: fe30b22a3543d51cc717ebddb6a5ea8325a60491 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
"""Barlow twins loss function."""

import torch
from torch import nn, Tensor


def off_diagonal(x: Tensor) -> Tensor:
    """Return the off-diagonal elements of a square matrix as a flat tensor.

    Args:
        x: Square 2-D tensor of shape (n, n).

    Returns:
        1-D tensor of length ``n * (n - 1)`` holding the off-diagonal entries
        in row-major order.

    Raises:
        ValueError: If ``x`` is not a square 2-D tensor.
    """
    if x.dim() != 2 or x.shape[0] != x.shape[1]:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        raise ValueError(
            f"off_diagonal expects a square 2-D tensor, got shape {tuple(x.shape)}"
        )
    n = x.shape[0]
    if n == 0:
        # `.view(-1, 1)` below is ambiguous for zero elements; an empty matrix
        # simply has no off-diagonal entries.
        return x.flatten()
    # Dropping the last (diagonal) element of the flattened matrix leaves
    # n*n - 1 == (n - 1) * (n + 1) entries. Viewed as (n - 1, n + 1), every
    # remaining diagonal element lands in column 0, so slicing [:, 1:] removes them.
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()


class BarlowTwinsLoss(nn.Module):
    """Barlow Twins redundancy-reduction loss.

    Each embedding is standardized per feature by a shared, non-affine
    ``BatchNorm1d``. The cross-correlation matrix of the two standardized
    embeddings (averaged over the batch) is then driven toward the identity.
    Diagonal entries are penalized for deviating from 1, and off-diagonal
    entries for deviating from 0, weighted by ``lambda_``.

    Args:
        dim: Width of the ``BatchNorm1d`` layer, i.e. the feature count of the
            embeddings.
        lambda_: Weight of the off-diagonal (redundancy) term.
    """

    def __init__(self, dim: int, lambda_: float) -> None:
        super().__init__()
        self.bn = nn.BatchNorm1d(dim, affine=False)
        self.lambda_ = lambda_

    def forward(self, z1: Tensor, z2: Tensor) -> Tensor:
        """Calculates the Barlow Twin loss.

        Args:
            z1: Embeddings of the first view; assumed (batch, dim), TODO confirm.
            z2: Embeddings of the second view, with the same batch size as ``z1``.

        Returns:
            Scalar tensor: invariance term plus ``lambda_`` times redundancy term.
        """
        # Normalize z1 first, then z2, so BatchNorm running stats update in the same order.
        normed_a = self.bn(z1)
        normed_b = self.bn(z2)
        batch_size = z1.shape[0]

        # (dim, batch) @ (batch, dim) -> (dim, dim) cross-correlation, batch-averaged.
        cross_corr = normed_a.T @ normed_b
        cross_corr = cross_corr / batch_size

        # Invariance: diagonal entries should equal 1.
        invariance = (torch.diagonal(cross_corr) - 1).pow(2).sum()
        # Redundancy reduction: off-diagonal entries should equal 0.
        redundancy = off_diagonal(cross_corr).pow(2).sum()
        return invariance + self.lambda_ * redundancy