diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-27 22:27:24 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-27 22:27:24 +0200 |
commit | 9a8044f4a3826a119416665741b709cd686fca87 (patch) | |
tree | e339593bb4e3858fa9379d14752dc52bf5949825 /text_recognizer/criterions | |
parent | ae8bfa62f0e02bd70c27bc1e71697249a5a79e7e (diff) |
Remove Barlow Twins
Diffstat (limited to 'text_recognizer/criterions')
-rw-r--r-- | text_recognizer/criterions/barlow_twins.py | 26 |
1 files changed, 0 insertions, 26 deletions
diff --git a/text_recognizer/criterions/barlow_twins.py b/text_recognizer/criterions/barlow_twins.py deleted file mode 100644 index fe30b22..0000000 --- a/text_recognizer/criterions/barlow_twins.py +++ /dev/null @@ -1,26 +0,0 @@ -"""Barlow twins loss function.""" - -import torch -from torch import nn, Tensor - - -def off_diagonal(x: Tensor) -> Tensor: - n, m = x.shape - assert n == m - return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() - - -class BarlowTwinsLoss(nn.Module): - def __init__(self, dim: int, lambda_: float) -> None: - super().__init__() - self.bn = nn.BatchNorm1d(dim, affine=False) - self.lambda_ = lambda_ - - def forward(self, z1: Tensor, z2: Tensor) -> Tensor: - """Calculates the Barlow Twin loss.""" - c = self.bn(z1).T @ self.bn(z2) - c.div_(z1.shape[0]) - - on_diag = torch.diagonal(c).add_(-1).pow_(2).sum() - off_diag = off_diagonal(c).pow_(2).sum() - return on_diag + self.lambda_ * off_diag |