diff options
Diffstat (limited to 'text_recognizer/criterions')
-rw-r--r-- | text_recognizer/criterions/barlow_twins.py | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/text_recognizer/criterions/barlow_twins.py b/text_recognizer/criterions/barlow_twins.py new file mode 100644 index 0000000..fe30b22 --- /dev/null +++ b/text_recognizer/criterions/barlow_twins.py @@ -0,0 +1,26 @@ +"""Barlow twins loss function.""" + +import torch +from torch import nn, Tensor + + +def off_diagonal(x: Tensor) -> Tensor: + n, m = x.shape + assert n == m + return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() + + +class BarlowTwinsLoss(nn.Module): + def __init__(self, dim: int, lambda_: float) -> None: + super().__init__() + self.bn = nn.BatchNorm1d(dim, affine=False) + self.lambda_ = lambda_ + + def forward(self, z1: Tensor, z2: Tensor) -> Tensor: + """Calculates the Barlow Twin loss.""" + c = self.bn(z1).T @ self.bn(z2) + c.div_(z1.shape[0]) + + on_diag = torch.diagonal(c).add_(-1).pow_(2).sum() + off_diag = off_diagonal(c).pow_(2).sum() + return on_diag + self.lambda_ * off_diag |