From 67269651f1ba013079ced12d5a7cd896fae59c5b Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Thu, 7 Oct 2021 22:06:58 +0200
Subject: Add adaptive pool between encoder and projector

---
 text_recognizer/models/barlow_twins.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/text_recognizer/models/barlow_twins.py b/text_recognizer/models/barlow_twins.py
index d044ba2..a9466ab 100644
--- a/text_recognizer/models/barlow_twins.py
+++ b/text_recognizer/models/barlow_twins.py
@@ -1,10 +1,11 @@
 """PyTorch Lightning Barlow Twins model."""
-from typing import Type
+from typing import Tuple, Type
 import attr
 import pytorch_lightning as pl
 import torch
 from torch import nn
 from torch import Tensor
+import torch.nn.functional as F
 import torchvision.transforms as T
 
 from text_recognizer.models.base import BaseLitModel
@@ -41,7 +42,8 @@ class BarlowTwinsLitModel(BaseLitModel):
 
     def forward(self, data: Tensor) -> Tensor:
         """Encodes image to projector latent."""
-        z_e = self.encoder(data).flatten(start_dim=1)
+        z = self.encoder(data)
+        z_e = F.adaptive_avg_pool2d(z, (1, 1)).flatten(start_dim=1)
         z_p = self.projector(z_e)
         return z_p
 
-- 
cgit v1.2.3-70-g09d2