From 9a8044f4a3826a119416665741b709cd686fca87 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Wed, 27 Oct 2021 22:27:24 +0200 Subject: Remove Barlow Twins --- text_recognizer/models/barlow_twins.py | 45 ---------------------------------- 1 file changed, 45 deletions(-) delete mode 100644 text_recognizer/models/barlow_twins.py (limited to 'text_recognizer/models/barlow_twins.py') diff --git a/text_recognizer/models/barlow_twins.py b/text_recognizer/models/barlow_twins.py deleted file mode 100644 index 6e2719d..0000000 --- a/text_recognizer/models/barlow_twins.py +++ /dev/null @@ -1,45 +0,0 @@ -"""PyTorch Lightning Barlow Twins model.""" -from typing import Tuple, Type -import attr -from torch import nn -from torch import Tensor - -from text_recognizer.models.base import BaseLitModel -from text_recognizer.criterions.barlow_twins import BarlowTwinsLoss - - -@attr.s(auto_attribs=True, eq=False) -class BarlowTwinsLitModel(BaseLitModel): - """Barlow Twins training proceduer.""" - - network: Type[nn.Module] = attr.ib() - loss_fn: BarlowTwinsLoss = attr.ib() - - def forward(self, data: Tensor) -> Tensor: - """Encodes image to projector latent.""" - return self.network(data) - - def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: - """Training step.""" - data, _ = batch - x1, x2 = data - z1, z2 = self(x1), self(x2) - loss = self.loss_fn(z1, z2) - self.log("train/loss", loss) - return loss - - def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: - """Validation step.""" - data, _ = batch - x1, x2 = data - z1, z2 = self(x1), self(x2) - loss = self.loss_fn(z1, z2) - self.log("val/loss", loss, prog_bar=True) - - def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: - """Test step.""" - data, _ = batch - x1, x2 = data - z1, z2 = self(x1), self(x2) - loss = self.loss_fn(z1, z2) - self.log("test/loss", loss, prog_bar=True) -- cgit v1.2.3-70-g09d2