diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-07 08:56:40 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-07 08:56:40 +0200 |
commit | 484dc2b09c87729b4e777e94efdd2e7583651df9 (patch) | |
tree | dc96e4c5bf8d1a171aa087bd518588baacabce80 /text_recognizer/networks | |
parent | 947d0209547cb4fcb95f47e8b8a47856092d7978 (diff) |
Add Barlow Twins network and training proceduer
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/barlow_twins/__init__.py | 1 | ||||
-rw-r--r-- | text_recognizer/networks/barlow_twins/projector.py | 36 |
2 files changed, 37 insertions, 0 deletions
diff --git a/text_recognizer/networks/barlow_twins/__init__.py b/text_recognizer/networks/barlow_twins/__init__.py new file mode 100644 index 0000000..0b74818 --- /dev/null +++ b/text_recognizer/networks/barlow_twins/__init__.py @@ -0,0 +1 @@ +"""Module for projector network in Barlow Twins.""" diff --git a/text_recognizer/networks/barlow_twins/projector.py b/text_recognizer/networks/barlow_twins/projector.py new file mode 100644 index 0000000..05d5e2e --- /dev/null +++ b/text_recognizer/networks/barlow_twins/projector.py @@ -0,0 +1,36 @@ +"""Projector network in Barlow Twins.""" + +from typing import List +import torch +from torch import nn +from torch import Tensor + + +class Projector(nn.Module): + """MLP network.""" + + def __init__(self, dims: List[int]) -> None: + super().__init__() + self.dims = dims + self.network = self._build() + + def _build(self) -> nn.Sequential: + """Builds projector network.""" + layers = [ + nn.Sequential( + nn.Linear( + in_features=self.dims[i], out_features=self.dims[i + 1], bias=False + ), + nn.BatchNorm1d(self.dims[i + 1]), + nn.ReLU(inplace=True), + ) + for i in range(len(self.dims) - 2) + ] + layers.append( + nn.Linear(in_features=self.dims[-2], out_features=self.dims[-1], bias=False) + ) + return nn.Sequential(*layers) + + def forward(self, x: Tensor) -> Tensor: + """Project latent to higher dimesion.""" + return self.network(x) |