blob: 05d5e2efdeef720ce140dd9aa83df01b69191ee7 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
|
"""Projector network in Barlow Twins."""
from typing import List
import torch
from torch import nn
from torch import Tensor
class Projector(nn.Module):
    """Multi-layer perceptron projector head used in Barlow Twins.

    Maps backbone embeddings through a stack of
    Linear -> BatchNorm1d -> ReLU blocks and finishes with a single
    bias-free Linear layer.
    """

    def __init__(self, dims: List[int]) -> None:
        super().__init__()
        # Layer widths, e.g. [2048, 8192, 8192, 8192]: dims[0] is the
        # input width, dims[-1] the projector output width.
        self.dims = dims
        self.network = self._build()

    def _build(self) -> nn.Sequential:
        """Assemble the projector as one sequential stack of layers."""
        hidden_blocks = []
        for idx in range(len(self.dims) - 2):
            width_in, width_out = self.dims[idx], self.dims[idx + 1]
            hidden_blocks.append(
                nn.Sequential(
                    nn.Linear(in_features=width_in, out_features=width_out, bias=False),
                    nn.BatchNorm1d(width_out),
                    nn.ReLU(inplace=True),
                )
            )
        # Final projection carries no bias and no activation.
        final = nn.Linear(
            in_features=self.dims[-2], out_features=self.dims[-1], bias=False
        )
        return nn.Sequential(*hidden_blocks, final)

    def forward(self, x: Tensor) -> Tensor:
        """Run *x* through the projector network and return the embedding."""
        return self.network(x)
|