summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/barlow_twins/projector.py
blob: 05d5e2efdeef720ce140dd9aa83df01b69191ee7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
"""Projector network in Barlow Twins."""

from typing import List
import torch
from torch import nn
from torch import Tensor


class Projector(nn.Module):
    """MLP projector head used by Barlow Twins.

    Builds ``len(dims) - 1`` linear layers. Every layer except the last is
    followed by ``BatchNorm1d`` and an in-place ``ReLU``; the final layer is a
    bare ``Linear``. All linear layers are created without a bias term.

    Args:
        dims: Layer widths ``[in_dim, hidden_1, ..., out_dim]``. Must contain
            at least two entries (the input and the output width).

    Raises:
        ValueError: If ``dims`` has fewer than two entries.
    """

    def __init__(self, dims: List[int]) -> None:
        super().__init__()
        # Fail early with a clear message instead of an IndexError from
        # ``dims[-2]`` deep inside ``_build``.
        if len(dims) < 2:
            raise ValueError(
                f"Projector needs at least 2 dims (input and output), got {dims!r}"
            )
        self.dims = dims
        self.network = self._build()

    def _build(self) -> nn.Sequential:
        """Builds the projector network.

        Returns:
            ``nn.Sequential`` holding one ``Sequential(Linear, BatchNorm1d,
            ReLU)`` per hidden transition, followed by the final ``Linear``.
            This nesting is kept deliberately so parameter names in the state
            dict (``network.<i>.<j>.*`` / ``network.<last>.weight``) remain
            stable for existing checkpoints.
        """
        # Consecutive (in_features, out_features) pairs; the last pair is the
        # output projection, which gets no normalization/activation.
        *hidden, (last_in, last_out) = zip(self.dims[:-1], self.dims[1:])
        layers: List[nn.Module] = [
            nn.Sequential(
                nn.Linear(in_features=d_in, out_features=d_out, bias=False),
                nn.BatchNorm1d(d_out),
                nn.ReLU(inplace=True),
            )
            for d_in, d_out in hidden
        ]
        layers.append(
            nn.Linear(in_features=last_in, out_features=last_out, bias=False)
        )
        return nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Project a latent batch to ``dims[-1]`` features.

        Args:
            x: Input whose last dimension is ``dims[0]``. When hidden layers
                exist, ``BatchNorm1d`` normalizes dim 1 of each Linear output,
                so a 2-D ``(batch, dims[0])`` input is what lines up
                (assumed layout; confirm against callers).

        Returns:
            The projected tensor, with last dimension ``dims[-1]``.
        """
        return self.network(x)