From c564117e02b0cd13896e044ba61265149780a406 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 10 Oct 2021 18:08:01 +0200 Subject: Update barlow model --- text_recognizer/models/barlow_twins.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) (limited to 'text_recognizer/models/barlow_twins.py') diff --git a/text_recognizer/models/barlow_twins.py b/text_recognizer/models/barlow_twins.py index a9466ab..36638e7 100644 --- a/text_recognizer/models/barlow_twins.py +++ b/text_recognizer/models/barlow_twins.py @@ -1,7 +1,6 @@ """PyTorch Lightning Barlow Twins model.""" from typing import Tuple, Type import attr -import pytorch_lightning as pl import torch from torch import nn from torch import Tensor @@ -22,10 +21,9 @@ def off_diagonal(x: Tensor) -> Tensor: class BarlowTwinsLitModel(BaseLitModel): """Barlow Twins training proceduer.""" - encoder: Type[nn.Module] = attr.ib() projector: Projector = attr.ib() lambda_: float = attr.ib() - augment: T.Compose = attr.ib() + augment: Type[T.Compose] = attr.ib() def __attrs_post_init__(self) -> None: """Post init configuration.""" @@ -42,7 +40,7 @@ class BarlowTwinsLitModel(BaseLitModel): def forward(self, data: Tensor) -> Tensor: """Encodes image to projector latent.""" - z = self.encoder(data) + z = self.network(data) z_e = F.adaptive_avg_pool2d(z, (1, 1)).flatten(start_dim=1) z_p = self.projector(z_e) return z_p -- cgit v1.2.3-70-g09d2