diff options
Diffstat (limited to 'text_recognizer/models/barlow_twins.py')
-rw-r--r-- | text_recognizer/models/barlow_twins.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/text_recognizer/models/barlow_twins.py b/text_recognizer/models/barlow_twins.py index d044ba2..a9466ab 100644 --- a/text_recognizer/models/barlow_twins.py +++ b/text_recognizer/models/barlow_twins.py @@ -1,10 +1,11 @@ """PyTorch Lightning Barlow Twins model.""" -from typing import Type +from typing import Tuple, Type import attr import pytorch_lightning as pl import torch from torch import nn from torch import Tensor +import torch.nn.functional as F import torchvision.transforms as T from text_recognizer.models.base import BaseLitModel @@ -41,7 +42,8 @@ class BarlowTwinsLitModel(BaseLitModel): def forward(self, data: Tensor) -> Tensor: """Encodes image to projector latent.""" - z_e = self.encoder(data).flatten(start_dim=1) + z = self.encoder(data) + z_e = F.adaptive_avg_pool2d(z, (1, 1)).flatten(start_dim=1) z_p = self.projector(z_e) return z_p |