diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-10 18:08:01 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-10 18:08:01 +0200 |
commit | c564117e02b0cd13896e044ba61265149780a406 (patch) | |
tree | 1030dca23bca7719bc2ac902711a36ad6d47dc5c /text_recognizer/models/barlow_twins.py | |
parent | 793f1ee1a653dd6a4eb47dab357aa5c0e2a9eb72 (diff) |
Update barlow model
Diffstat (limited to 'text_recognizer/models/barlow_twins.py')
-rw-r--r-- | text_recognizer/models/barlow_twins.py | 6 |
1 files changed, 2 insertions, 4 deletions
diff --git a/text_recognizer/models/barlow_twins.py b/text_recognizer/models/barlow_twins.py index a9466ab..36638e7 100644 --- a/text_recognizer/models/barlow_twins.py +++ b/text_recognizer/models/barlow_twins.py @@ -1,7 +1,6 @@ """PyTorch Lightning Barlow Twins model.""" from typing import Tuple, Type import attr -import pytorch_lightning as pl import torch from torch import nn from torch import Tensor @@ -22,10 +21,9 @@ def off_diagonal(x: Tensor) -> Tensor: class BarlowTwinsLitModel(BaseLitModel): """Barlow Twins training proceduer.""" - encoder: Type[nn.Module] = attr.ib() projector: Projector = attr.ib() lambda_: float = attr.ib() - augment: T.Compose = attr.ib() + augment: Type[T.Compose] = attr.ib() def __attrs_post_init__(self) -> None: """Post init configuration.""" @@ -42,7 +40,7 @@ class BarlowTwinsLitModel(BaseLitModel): def forward(self, data: Tensor) -> Tensor: """Encodes image to projector latent.""" - z = self.encoder(data) + z = self.network(data) z_e = F.adaptive_avg_pool2d(z, (1, 1)).flatten(start_dim=1) z_p = self.projector(z_e) return z_p |