From 484dc2b09c87729b4e777e94efdd2e7583651df9 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Thu, 7 Oct 2021 08:56:40 +0200
Subject: Add Barlow Twins network and training procedure

---
 text_recognizer/networks/barlow_twins/__init__.py  |  1 +
 text_recognizer/networks/barlow_twins/projector.py | 36 ++++++++++++++++++++++
 2 files changed, 37 insertions(+)
 create mode 100644 text_recognizer/networks/barlow_twins/__init__.py
 create mode 100644 text_recognizer/networks/barlow_twins/projector.py

(limited to 'text_recognizer/networks/barlow_twins')

diff --git a/text_recognizer/networks/barlow_twins/__init__.py b/text_recognizer/networks/barlow_twins/__init__.py
new file mode 100644
index 0000000..0b74818
--- /dev/null
+++ b/text_recognizer/networks/barlow_twins/__init__.py
@@ -0,0 +1 @@
+"""Module for projector network in Barlow Twins."""
diff --git a/text_recognizer/networks/barlow_twins/projector.py b/text_recognizer/networks/barlow_twins/projector.py
new file mode 100644
index 0000000..05d5e2e
--- /dev/null
+++ b/text_recognizer/networks/barlow_twins/projector.py
@@ -0,0 +1,36 @@
+"""Projector network in Barlow Twins."""
+
+from typing import List
+import torch
+from torch import nn
+from torch import Tensor
+
+
class Projector(nn.Module):
    """MLP projector network used in Barlow Twins.

    Maps a latent vector through a stack of Linear -> BatchNorm1d -> ReLU
    blocks, followed by a final bias-free Linear layer. Whether the output
    dimension is higher or lower than the input depends entirely on ``dims``.
    """

    def __init__(self, dims: List[int]) -> None:
        """Initializes the projector.

        Args:
            dims: Layer sizes ``[in, hidden, ..., out]``. Must contain at
                least two entries (input and output size).

        Raises:
            ValueError: If ``dims`` has fewer than two entries.
        """
        super().__init__()
        if len(dims) < 2:
            raise ValueError(
                "dims must contain at least an input and an output size, "
                f"got {len(dims)} entries"
            )
        self.dims = dims
        self.network = self._build()

    def _build(self) -> nn.Sequential:
        """Builds the projector network from ``self.dims``."""
        # Hidden blocks: bias-free Linear (BatchNorm supplies the shift),
        # then BatchNorm1d and ReLU. The last pair of dims is handled below,
        # hence the ``- 2`` in the range.
        layers = [
            nn.Sequential(
                nn.Linear(
                    in_features=self.dims[i], out_features=self.dims[i + 1], bias=False
                ),
                nn.BatchNorm1d(self.dims[i + 1]),
                nn.ReLU(inplace=True),
            )
            for i in range(len(self.dims) - 2)
        ]
        # Final projection layer: no normalization or activation.
        layers.append(
            nn.Linear(in_features=self.dims[-2], out_features=self.dims[-1], bias=False)
        )
        return nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Projects a batch of latents ``x`` of shape (batch, dims[0]) to
        shape (batch, dims[-1])."""
        return self.network(x)
-- 
cgit v1.2.3-70-g09d2