diff options
Diffstat (limited to 'text_recognizer/criterions')
-rw-r--r-- | text_recognizer/criterions/barlow_twins.py | 26 |
1 files changed, 0 insertions, 26 deletions
diff --git a/text_recognizer/criterions/barlow_twins.py b/text_recognizer/criterions/barlow_twins.py deleted file mode 100644 index fe30b22..0000000 --- a/text_recognizer/criterions/barlow_twins.py +++ /dev/null @@ -1,26 +0,0 @@ -"""Barlow twins loss function.""" - -import torch -from torch import nn, Tensor - - -def off_diagonal(x: Tensor) -> Tensor: - n, m = x.shape - assert n == m - return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() - - -class BarlowTwinsLoss(nn.Module): - def __init__(self, dim: int, lambda_: float) -> None: - super().__init__() - self.bn = nn.BatchNorm1d(dim, affine=False) - self.lambda_ = lambda_ - - def forward(self, z1: Tensor, z2: Tensor) -> Tensor: - """Calculates the Barlow Twin loss.""" - c = self.bn(z1).T @ self.bn(z2) - c.div_(z1.shape[0]) - - on_diag = torch.diagonal(c).add_(-1).pow_(2).sum() - off_diag = off_diagonal(c).pow_(2).sum() - return on_diag + self.lambda_ * off_diag |