diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-11 22:13:45 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-11 22:13:45 +0200 |
commit | 9e0496ace20d5b7e3cde0dcfc1e8400039e51916 (patch) | |
tree | e194e74bc1871c8cfc4e13f26a566d330b57e9b1 /text_recognizer/criterions | |
parent | 7cdfb62c94d3cad808bfaf198272bcfb66734711 (diff) |
Add Barlow Loss
Diffstat (limited to 'text_recognizer/criterions')
-rw-r--r-- | text_recognizer/criterions/barlow_twins.py | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/text_recognizer/criterions/barlow_twins.py b/text_recognizer/criterions/barlow_twins.py new file mode 100644 index 0000000..fe30b22 --- /dev/null +++ b/text_recognizer/criterions/barlow_twins.py @@ -0,0 +1,26 @@ +"""Barlow twins loss function.""" + +import torch +from torch import nn, Tensor + + +def off_diagonal(x: Tensor) -> Tensor: + n, m = x.shape + assert n == m + return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() + + +class BarlowTwinsLoss(nn.Module): + def __init__(self, dim: int, lambda_: float) -> None: + super().__init__() + self.bn = nn.BatchNorm1d(dim, affine=False) + self.lambda_ = lambda_ + + def forward(self, z1: Tensor, z2: Tensor) -> Tensor: + """Calculates the Barlow Twin loss.""" + c = self.bn(z1).T @ self.bn(z2) + c.div_(z1.shape[0]) + + on_diag = torch.diagonal(c).add_(-1).pow_(2).sum() + off_diag = off_diagonal(c).pow_(2).sum() + return on_diag + self.lambda_ * off_diag |