summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/barlow_twins/network.py
blob: a3e37501b620dab3ca3ca3018e23896296584ca9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
"""Barlow Twins network."""
from typing import Type

from torch import nn, Tensor
import torch.nn.functional as F


class BarlowTwins(nn.Module):
    """Barlow Twins network: an encoder followed by a projector head.

    The encoder's output feature map is globally average-pooled and flattened
    into one vector per sample before being passed to the projector.

    Args:
        encoder: Module *instance* mapping the input to a feature map. Its
            output is fed to ``F.adaptive_avg_pool2d``, so it must be 3-D or
            4-D; presumably (B, C, H, W) — TODO confirm with callers.
        projector: Module *instance* applied to the pooled, flattened
            encoder features.
    """

    def __init__(self, encoder: nn.Module, projector: nn.Module) -> None:
        super().__init__()
        # Assigning nn.Module instances as attributes registers them as
        # submodules, so their parameters are tracked by this model.
        self.encoder = encoder
        self.projector = projector

    def forward(self, x: Tensor) -> Tensor:
        """Encode ``x``, pool the features globally, and project them.

        Args:
            x: Input batch, passed unchanged to the encoder.

        Returns:
            The projector's output for the pooled encoder features.
        """
        z = self.encoder(x)
        # For a 4-D (B, C, H, W) feature map: average over the spatial dims
        # -> (B, C, 1, 1), then flatten everything after dim 0 -> (B, C).
        z_e = F.adaptive_avg_pool2d(z, (1, 1)).flatten(start_dim=1)
        z_p = self.projector(z_e)
        return z_p