diff options
| -rw-r--r-- | text_recognizer/criterions/barlow_twins.py | 26 | 
1 files changed, 26 insertions, 0 deletions
diff --git a/text_recognizer/criterions/barlow_twins.py b/text_recognizer/criterions/barlow_twins.py new file mode 100644 index 0000000..fe30b22 --- /dev/null +++ b/text_recognizer/criterions/barlow_twins.py @@ -0,0 +1,26 @@ +"""Barlow twins loss function.""" + +import torch +from torch import nn, Tensor + + +def off_diagonal(x: Tensor) -> Tensor: +    n, m = x.shape +    assert n == m +    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() + + +class BarlowTwinsLoss(nn.Module): +    def __init__(self, dim: int, lambda_: float) -> None: +        super().__init__() +        self.bn = nn.BatchNorm1d(dim, affine=False) +        self.lambda_ = lambda_ + +    def forward(self, z1: Tensor, z2: Tensor) -> Tensor: +        """Calculates the Barlow Twin loss.""" +        c = self.bn(z1).T @ self.bn(z2) +        c.div_(z1.shape[0]) + +        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum() +        off_diag = off_diagonal(c).pow_(2).sum() +        return on_diag + self.lambda_ * off_diag  |