blob: a3e37501b620dab3ca3ca3018e23896296584ca9 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
|
"""Barlow Twins network."""
from typing import Type
from torch import nn, Tensor
import torch.nn.functional as F
class BarlowTwins(nn.Module):
    """Encoder followed by global average pooling and a projector head.

    The forward pass runs the input through ``encoder``, average-pools the
    result down to a 1x1 spatial map, flattens it from dimension 1 onwards
    and feeds the flattened features to ``projector``.
    """

    def __init__(self, encoder: nn.Module, projector: nn.Module) -> None:
        """Store the two sub-networks.

        Args:
            encoder: Module *instance* applied to the raw input. It is called
                directly in ``forward``, so it must be an instantiated module,
                not a module class.
            projector: Module *instance* applied to the pooled, flattened
                encoder output.
        """
        super().__init__()
        # Assigning nn.Module instances as attributes registers them as
        # sub-modules, so their parameters are tracked by this module.
        self.encoder = encoder
        self.projector = projector

    def forward(self, x: Tensor) -> Tensor:
        """Return the projector output for input ``x``.

        Args:
            x: Input tensor passed unchanged to ``encoder``.

        Returns:
            The output of ``projector`` applied to the pooled encoder
            features.
        """
        features = self.encoder(x)
        # Collapse the last two dimensions to 1x1, then flatten everything
        # after dim 0.
        # NOTE(review): assumes the encoder returns a batched (N, C, H, W)
        # map so the flattened result is (N, C) -- confirm against the
        # encoders actually used.
        pooled = F.adaptive_avg_pool2d(features, (1, 1)).flatten(start_dim=1)
        return self.projector(pooled)
|