summaryrefslogtreecommitdiff
path: root/text_recognizer/models/barlow_twins.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/models/barlow_twins.py')
-rw-r--r--text_recognizer/models/barlow_twins.py45
1 files changed, 0 insertions, 45 deletions
diff --git a/text_recognizer/models/barlow_twins.py b/text_recognizer/models/barlow_twins.py
deleted file mode 100644
index 6e2719d..0000000
--- a/text_recognizer/models/barlow_twins.py
+++ /dev/null
@@ -1,45 +0,0 @@
-"""PyTorch Lightning Barlow Twins model."""
-from typing import Tuple, Type
-import attr
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.models.base import BaseLitModel
-from text_recognizer.criterions.barlow_twins import BarlowTwinsLoss
-
-
-@attr.s(auto_attribs=True, eq=False)
-class BarlowTwinsLitModel(BaseLitModel):
- """Barlow Twins training proceduer."""
-
- network: Type[nn.Module] = attr.ib()
- loss_fn: BarlowTwinsLoss = attr.ib()
-
- def forward(self, data: Tensor) -> Tensor:
- """Encodes image to projector latent."""
- return self.network(data)
-
- def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
- """Training step."""
- data, _ = batch
- x1, x2 = data
- z1, z2 = self(x1), self(x2)
- loss = self.loss_fn(z1, z2)
- self.log("train/loss", loss)
- return loss
-
- def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
- """Validation step."""
- data, _ = batch
- x1, x2 = data
- z1, z2 = self(x1), self(x2)
- loss = self.loss_fn(z1, z2)
- self.log("val/loss", loss, prog_bar=True)
-
- def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
- """Test step."""
- data, _ = batch
- x1, x2 = data
- z1, z2 = self(x1), self(x2)
- loss = self.loss_fn(z1, z2)
- self.log("test/loss", loss, prog_bar=True)