diff options
Diffstat (limited to 'text_recognizer/models')
-rw-r--r-- | text_recognizer/models/barlow_twins.py | 45 |
1 files changed, 0 insertions, 45 deletions
diff --git a/text_recognizer/models/barlow_twins.py b/text_recognizer/models/barlow_twins.py deleted file mode 100644 index 6e2719d..0000000 --- a/text_recognizer/models/barlow_twins.py +++ /dev/null @@ -1,45 +0,0 @@ -"""PyTorch Lightning Barlow Twins model.""" -from typing import Tuple, Type -import attr -from torch import nn -from torch import Tensor - -from text_recognizer.models.base import BaseLitModel -from text_recognizer.criterions.barlow_twins import BarlowTwinsLoss - - -@attr.s(auto_attribs=True, eq=False) -class BarlowTwinsLitModel(BaseLitModel): - """Barlow Twins training proceduer.""" - - network: Type[nn.Module] = attr.ib() - loss_fn: BarlowTwinsLoss = attr.ib() - - def forward(self, data: Tensor) -> Tensor: - """Encodes image to projector latent.""" - return self.network(data) - - def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: - """Training step.""" - data, _ = batch - x1, x2 = data - z1, z2 = self(x1), self(x2) - loss = self.loss_fn(z1, z2) - self.log("train/loss", loss) - return loss - - def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: - """Validation step.""" - data, _ = batch - x1, x2 = data - z1, z2 = self(x1), self(x2) - loss = self.loss_fn(z1, z2) - self.log("val/loss", loss, prog_bar=True) - - def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: - """Test step.""" - data, _ = batch - x1, x2 = data - z1, z2 = self(x1), self(x2) - loss = self.loss_fn(z1, z2) - self.log("test/loss", loss, prog_bar=True) |