From 484dc2b09c87729b4e777e94efdd2e7583651df9 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Thu, 7 Oct 2021 08:56:40 +0200 Subject: Add Barlow Twins network and training proceduer --- text_recognizer/networks/barlow_twins/__init__.py | 1 + text_recognizer/networks/barlow_twins/projector.py | 36 ++++++++++++++++++++++ 2 files changed, 37 insertions(+) create mode 100644 text_recognizer/networks/barlow_twins/__init__.py create mode 100644 text_recognizer/networks/barlow_twins/projector.py (limited to 'text_recognizer/networks') diff --git a/text_recognizer/networks/barlow_twins/__init__.py b/text_recognizer/networks/barlow_twins/__init__.py new file mode 100644 index 0000000..0b74818 --- /dev/null +++ b/text_recognizer/networks/barlow_twins/__init__.py @@ -0,0 +1 @@ +"""Module for projector network in Barlow Twins.""" diff --git a/text_recognizer/networks/barlow_twins/projector.py b/text_recognizer/networks/barlow_twins/projector.py new file mode 100644 index 0000000..05d5e2e --- /dev/null +++ b/text_recognizer/networks/barlow_twins/projector.py @@ -0,0 +1,36 @@ +"""Projector network in Barlow Twins.""" + +from typing import List +import torch +from torch import nn +from torch import Tensor + + +class Projector(nn.Module): + """MLP network.""" + + def __init__(self, dims: List[int]) -> None: + super().__init__() + self.dims = dims + self.network = self._build() + + def _build(self) -> nn.Sequential: + """Builds projector network.""" + layers = [ + nn.Sequential( + nn.Linear( + in_features=self.dims[i], out_features=self.dims[i + 1], bias=False + ), + nn.BatchNorm1d(self.dims[i + 1]), + nn.ReLU(inplace=True), + ) + for i in range(len(self.dims) - 2) + ] + layers.append( + nn.Linear(in_features=self.dims[-2], out_features=self.dims[-1], bias=False) + ) + return nn.Sequential(*layers) + + def forward(self, x: Tensor) -> Tensor: + """Project latent to higher dimesion.""" + return self.network(x) -- cgit v1.2.3-70-g09d2