summaryrefslogtreecommitdiff
path: root/text_recognizer/models/barlow_twins.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-27 22:27:24 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-27 22:27:24 +0200
commit9a8044f4a3826a119416665741b709cd686fca87 (patch)
treee339593bb4e3858fa9379d14752dc52bf5949825 /text_recognizer/models/barlow_twins.py
parentae8bfa62f0e02bd70c27bc1e71697249a5a79e7e (diff)
Remove Barlow Twins
Diffstat (limited to 'text_recognizer/models/barlow_twins.py')
-rw-r--r--text_recognizer/models/barlow_twins.py45
1 files changed, 0 insertions, 45 deletions
diff --git a/text_recognizer/models/barlow_twins.py b/text_recognizer/models/barlow_twins.py
deleted file mode 100644
index 6e2719d..0000000
--- a/text_recognizer/models/barlow_twins.py
+++ /dev/null
@@ -1,45 +0,0 @@
-"""PyTorch Lightning Barlow Twins model."""
-from typing import Tuple, Type
-import attr
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.models.base import BaseLitModel
-from text_recognizer.criterions.barlow_twins import BarlowTwinsLoss
-
-
-@attr.s(auto_attribs=True, eq=False)
-class BarlowTwinsLitModel(BaseLitModel):
- """Barlow Twins training proceduer."""
-
- network: Type[nn.Module] = attr.ib()
- loss_fn: BarlowTwinsLoss = attr.ib()
-
- def forward(self, data: Tensor) -> Tensor:
- """Encodes image to projector latent."""
- return self.network(data)
-
- def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
- """Training step."""
- data, _ = batch
- x1, x2 = data
- z1, z2 = self(x1), self(x2)
- loss = self.loss_fn(z1, z2)
- self.log("train/loss", loss)
- return loss
-
- def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
- """Validation step."""
- data, _ = batch
- x1, x2 = data
- z1, z2 = self(x1), self(x2)
- loss = self.loss_fn(z1, z2)
- self.log("val/loss", loss, prog_bar=True)
-
- def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
- """Test step."""
- data, _ = batch
- x1, x2 = data
- z1, z2 = self(x1), self(x2)
- loss = self.loss_fn(z1, z2)
- self.log("test/loss", loss, prog_bar=True)