diff options
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/models/barlow_twins.py | 71 | ||||
-rw-r--r-- | text_recognizer/networks/barlow_twins/__init__.py | 1 | ||||
-rw-r--r-- | text_recognizer/networks/barlow_twins/projector.py | 36 |
3 files changed, 108 insertions, 0 deletions
diff --git a/text_recognizer/models/barlow_twins.py b/text_recognizer/models/barlow_twins.py new file mode 100644 index 0000000..d044ba2 --- /dev/null +++ b/text_recognizer/models/barlow_twins.py @@ -0,0 +1,71 @@ +"""PyTorch Lightning Barlow Twins model.""" +from typing import Type +import attr +import pytorch_lightning as pl +import torch +from torch import nn +from torch import Tensor +import torchvision.transforms as T + +from text_recognizer.models.base import BaseLitModel +from text_recognizer.networks.barlow_twins.projector import Projector + + +def off_diagonal(x: Tensor) -> Tensor: + n, m = x.shape + assert n == m + return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() + + +@attr.s(auto_attribs=True, eq=False) +class BarlowTwinsLitModel(BaseLitModel): + """Barlow Twins training proceduer.""" + + encoder: Type[nn.Module] = attr.ib() + projector: Projector = attr.ib() + lambda_: float = attr.ib() + augment: T.Compose = attr.ib() + + def __attrs_post_init__(self) -> None: + """Post init configuration.""" + self.bn = nn.BatchNorm1d(self.projector.dims[-1], affine=False) + + def loss_fn(self, z1: Tensor, z2: Tensor) -> Tensor: + """Calculates the Barlow Twin loss.""" + c = torch.mm(self.bn(z1), self.bn(z2)) + c.div_(z1.shape[0]) + + on_diag = torch.diagonal(c).add_(-1).pow_(2).sum() + off_diag = off_diagonal(c).pow_(2).sum() + return on_diag + self.lambda_ * off_diag + + def forward(self, data: Tensor) -> Tensor: + """Encodes image to projector latent.""" + z_e = self.encoder(data).flatten(start_dim=1) + z_p = self.projector(z_e) + return z_p + + def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: + """Training step.""" + data, _ = batch + x1, x2 = self.augment(data), self.augment(data) + z1, z2 = self(x1), self(x2) + loss = self.loss_fn(z1, z2) + self.log("train/loss", loss) + return loss + + def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: + """Validation step.""" + data, _ = batch + x1, x2 = self.augment(data), self.augment(data) + z1, z2 = self(x1), self(x2) + loss = self.loss_fn(z1, z2) + self.log("val/loss", loss, prog_bar=True) + + def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: + """Test step.""" + data, _ = batch + x1, x2 = self.augment(data), self.augment(data) + z1, z2 = self(x1), self(x2) + loss = self.loss_fn(z1, z2) + self.log("test/loss", loss, prog_bar=True) diff --git a/text_recognizer/networks/barlow_twins/__init__.py b/text_recognizer/networks/barlow_twins/__init__.py new file mode 100644 index 0000000..0b74818 --- /dev/null +++ b/text_recognizer/networks/barlow_twins/__init__.py @@ -0,0 +1 @@ +"""Module for projector network in Barlow Twins.""" diff --git a/text_recognizer/networks/barlow_twins/projector.py b/text_recognizer/networks/barlow_twins/projector.py new file mode 100644 index 0000000..05d5e2e --- /dev/null +++ b/text_recognizer/networks/barlow_twins/projector.py @@ -0,0 +1,36 @@ +"""Projector network in Barlow Twins.""" + +from typing import List +import torch +from torch import nn +from torch import Tensor + + +class Projector(nn.Module): + """MLP network.""" + + def __init__(self, dims: List[int]) -> None: + super().__init__() + self.dims = dims + self.network = self._build() + + def _build(self) -> nn.Sequential: + """Builds projector network.""" + layers = [ + nn.Sequential( + nn.Linear( + in_features=self.dims[i], out_features=self.dims[i + 1], bias=False + ), + nn.BatchNorm1d(self.dims[i + 1]), + nn.ReLU(inplace=True), + ) + for i in range(len(self.dims) - 2) + ] + layers.append( + nn.Linear(in_features=self.dims[-2], out_features=self.dims[-1], bias=False) + ) + return nn.Sequential(*layers) + + def forward(self, x: Tensor) -> Tensor: + """Project latent to higher dimesion.""" + return self.network(x) |