diff options
-rw-r--r-- | text_recognizer/models/barlow_twins.py | 6 |
1 files changed, 2 insertions, 4 deletions
diff --git a/text_recognizer/models/barlow_twins.py b/text_recognizer/models/barlow_twins.py index a9466ab..36638e7 100644 --- a/text_recognizer/models/barlow_twins.py +++ b/text_recognizer/models/barlow_twins.py @@ -1,7 +1,6 @@ """PyTorch Lightning Barlow Twins model.""" from typing import Tuple, Type import attr -import pytorch_lightning as pl import torch from torch import nn from torch import Tensor @@ -22,10 +21,9 @@ def off_diagonal(x: Tensor) -> Tensor: class BarlowTwinsLitModel(BaseLitModel): """Barlow Twins training proceduer.""" - encoder: Type[nn.Module] = attr.ib() projector: Projector = attr.ib() lambda_: float = attr.ib() - augment: T.Compose = attr.ib() + augment: Type[T.Compose] = attr.ib() def __attrs_post_init__(self) -> None: """Post init configuration.""" @@ -42,7 +40,7 @@ class BarlowTwinsLitModel(BaseLitModel): def forward(self, data: Tensor) -> Tensor: """Encodes image to projector latent.""" - z = self.encoder(data) + z = self.network(data) z_e = F.adaptive_avg_pool2d(z, (1, 1)).flatten(start_dim=1) z_p = self.projector(z_e) return z_p |