From 2c377b6f7e2d4ba8a7c424c748053cc8d560599a Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 11 Oct 2021 22:09:40 +0200 Subject: Refactor barlow twins lit model --- text_recognizer/models/barlow_twins.py | 40 ++++++---------------------------- 1 file changed, 7 insertions(+), 33 deletions(-) (limited to 'text_recognizer/models') diff --git a/text_recognizer/models/barlow_twins.py b/text_recognizer/models/barlow_twins.py index 36638e7..6e2719d 100644 --- a/text_recognizer/models/barlow_twins.py +++ b/text_recognizer/models/barlow_twins.py @@ -1,54 +1,28 @@ """PyTorch Lightning Barlow Twins model.""" from typing import Tuple, Type import attr -import torch from torch import nn from torch import Tensor -import torch.nn.functional as F -import torchvision.transforms as T from text_recognizer.models.base import BaseLitModel -from text_recognizer.networks.barlow_twins.projector import Projector - - -def off_diagonal(x: Tensor) -> Tensor: - n, m = x.shape - assert n == m - return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() +from text_recognizer.criterions.barlow_twins import BarlowTwinsLoss @attr.s(auto_attribs=True, eq=False) class BarlowTwinsLitModel(BaseLitModel): """Barlow Twins training proceduer.""" - projector: Projector = attr.ib() - lambda_: float = attr.ib() - augment: Type[T.Compose] = attr.ib() - - def __attrs_post_init__(self) -> None: - """Post init configuration.""" - self.bn = nn.BatchNorm1d(self.projector.dims[-1], affine=False) - - def loss_fn(self, z1: Tensor, z2: Tensor) -> Tensor: - """Calculates the Barlow Twin loss.""" - c = torch.mm(self.bn(z1), self.bn(z2)) - c.div_(z1.shape[0]) - - on_diag = torch.diagonal(c).add_(-1).pow_(2).sum() - off_diag = off_diagonal(c).pow_(2).sum() - return on_diag + self.lambda_ * off_diag + network: Type[nn.Module] = attr.ib() + loss_fn: BarlowTwinsLoss = attr.ib() def forward(self, data: Tensor) -> Tensor: """Encodes image to projector latent.""" - z = self.network(data) - z_e = F.adaptive_avg_pool2d(z, (1, 1)).flatten(start_dim=1) - z_p = self.projector(z_e) - return z_p + return self.network(data) def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: """Training step.""" data, _ = batch - x1, x2 = self.augment(data), self.augment(data) + x1, x2 = data z1, z2 = self(x1), self(x2) loss = self.loss_fn(z1, z2) self.log("train/loss", loss) @@ -57,7 +31,7 @@ class BarlowTwinsLitModel(BaseLitModel): def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Validation step.""" data, _ = batch - x1, x2 = self.augment(data), self.augment(data) + x1, x2 = data z1, z2 = self(x1), self(x2) loss = self.loss_fn(z1, z2) self.log("val/loss", loss, prog_bar=True) @@ -65,7 +39,7 @@ class BarlowTwinsLitModel(BaseLitModel): def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Test step.""" data, _ = batch - x1, x2 = self.augment(data), self.augment(data) + x1, x2 = data z1, z2 = self(x1), self(x2) loss = self.loss_fn(z1, z2) self.log("test/loss", loss, prog_bar=True) -- cgit v1.2.3-70-g09d2