diff options
Diffstat (limited to 'text_recognizer/networks/barlow_twins')
-rw-r--r-- | text_recognizer/networks/barlow_twins/__init__.py | 1 | ||||
-rw-r--r-- | text_recognizer/networks/barlow_twins/projector.py | 36 |
2 files changed, 37 insertions, 0 deletions
diff --git a/text_recognizer/networks/barlow_twins/__init__.py b/text_recognizer/networks/barlow_twins/__init__.py new file mode 100644 index 0000000..0b74818 --- /dev/null +++ b/text_recognizer/networks/barlow_twins/__init__.py @@ -0,0 +1 @@ +"""Module for projector network in Barlow Twins.""" diff --git a/text_recognizer/networks/barlow_twins/projector.py b/text_recognizer/networks/barlow_twins/projector.py new file mode 100644 index 0000000..05d5e2e --- /dev/null +++ b/text_recognizer/networks/barlow_twins/projector.py @@ -0,0 +1,36 @@ +"""Projector network in Barlow Twins.""" + +from typing import List +import torch +from torch import nn +from torch import Tensor + + +class Projector(nn.Module): + """MLP network.""" + + def __init__(self, dims: List[int]) -> None: + super().__init__() + self.dims = dims + self.network = self._build() + + def _build(self) -> nn.Sequential: + """Builds projector network.""" + layers = [ + nn.Sequential( + nn.Linear( + in_features=self.dims[i], out_features=self.dims[i + 1], bias=False + ), + nn.BatchNorm1d(self.dims[i + 1]), + nn.ReLU(inplace=True), + ) + for i in range(len(self.dims) - 2) + ] + layers.append( + nn.Linear(in_features=self.dims[-2], out_features=self.dims[-1], bias=False) + ) + return nn.Sequential(*layers) + + def forward(self, x: Tensor) -> Tensor: + """Project latent to higher dimesion.""" + return self.network(x) |