summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--text_recognizer/models/barlow_twins.py6
1 files changed, 2 insertions, 4 deletions
diff --git a/text_recognizer/models/barlow_twins.py b/text_recognizer/models/barlow_twins.py
index a9466ab..36638e7 100644
--- a/text_recognizer/models/barlow_twins.py
+++ b/text_recognizer/models/barlow_twins.py
@@ -1,7 +1,6 @@
"""PyTorch Lightning Barlow Twins model."""
from typing import Tuple, Type
import attr
-import pytorch_lightning as pl
import torch
from torch import nn
from torch import Tensor
@@ -22,10 +21,9 @@ def off_diagonal(x: Tensor) -> Tensor:
class BarlowTwinsLitModel(BaseLitModel):
"""Barlow Twins training proceduer."""
- encoder: Type[nn.Module] = attr.ib()
projector: Projector = attr.ib()
lambda_: float = attr.ib()
- augment: T.Compose = attr.ib()
+ augment: Type[T.Compose] = attr.ib()
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
@@ -42,7 +40,7 @@ class BarlowTwinsLitModel(BaseLitModel):
def forward(self, data: Tensor) -> Tensor:
"""Encodes image to projector latent."""
- z = self.encoder(data)
+ z = self.network(data)
z_e = F.adaptive_avg_pool2d(z, (1, 1)).flatten(start_dim=1)
z_p = self.projector(z_e)
return z_p