diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-07 22:06:58 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-07 22:06:58 +0200 |
commit | 67269651f1ba013079ced12d5a7cd896fae59c5b (patch) | |
tree | 9942fd09656dcca562f0421f83c9aac259c11df4 /text_recognizer/models/barlow_twins.py | |
parent | cde2883054df7a2088e817e548c4b601b47633b6 (diff) |
Add adaptive pool between encoder and projector
Diffstat (limited to 'text_recognizer/models/barlow_twins.py')
-rw-r--r-- | text_recognizer/models/barlow_twins.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/text_recognizer/models/barlow_twins.py b/text_recognizer/models/barlow_twins.py index d044ba2..a9466ab 100644 --- a/text_recognizer/models/barlow_twins.py +++ b/text_recognizer/models/barlow_twins.py @@ -1,10 +1,11 @@ """PyTorch Lightning Barlow Twins model.""" -from typing import Type +from typing import Tuple, Type import attr import pytorch_lightning as pl import torch from torch import nn from torch import Tensor +import torch.nn.functional as F import torchvision.transforms as T from text_recognizer.models.base import BaseLitModel @@ -41,7 +42,8 @@ class BarlowTwinsLitModel(BaseLitModel): def forward(self, data: Tensor) -> Tensor: """Encodes image to projector latent.""" - z_e = self.encoder(data).flatten(start_dim=1) + z = self.encoder(data) + z_e = F.adaptive_avg_pool2d(z, (1, 1)).flatten(start_dim=1) z_p = self.projector(z_e) return z_p |