summaryrefslogtreecommitdiff
path: root/text_recognizer/models
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-07 22:06:58 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-07 22:06:58 +0200
commit67269651f1ba013079ced12d5a7cd896fae59c5b (patch)
tree9942fd09656dcca562f0421f83c9aac259c11df4 /text_recognizer/models
parentcde2883054df7a2088e817e548c4b601b47633b6 (diff)
Add adaptive pool between encoder and projector
Diffstat (limited to 'text_recognizer/models')
-rw-r--r--text_recognizer/models/barlow_twins.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/text_recognizer/models/barlow_twins.py b/text_recognizer/models/barlow_twins.py
index d044ba2..a9466ab 100644
--- a/text_recognizer/models/barlow_twins.py
+++ b/text_recognizer/models/barlow_twins.py
@@ -1,10 +1,11 @@
"""PyTorch Lightning Barlow Twins model."""
-from typing import Type
+from typing import Tuple, Type
import attr
import pytorch_lightning as pl
import torch
from torch import nn
from torch import Tensor
+import torch.nn.functional as F
import torchvision.transforms as T
from text_recognizer.models.base import BaseLitModel
@@ -41,7 +42,8 @@ class BarlowTwinsLitModel(BaseLitModel):
def forward(self, data: Tensor) -> Tensor:
"""Encodes image to projector latent."""
- z_e = self.encoder(data).flatten(start_dim=1)
+ z = self.encoder(data)
+ z_e = F.adaptive_avg_pool2d(z, (1, 1)).flatten(start_dim=1)
z_p = self.projector(z_e)
return z_p