summaryrefslogtreecommitdiff
path: root/text_recognizer/networks
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-07 08:56:40 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-07 08:56:40 +0200
commit484dc2b09c87729b4e777e94efdd2e7583651df9 (patch)
treedc96e4c5bf8d1a171aa087bd518588baacabce80 /text_recognizer/networks
parent947d0209547cb4fcb95f47e8b8a47856092d7978 (diff)
Add Barlow Twins network and training proceduer
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--text_recognizer/networks/barlow_twins/__init__.py1
-rw-r--r--text_recognizer/networks/barlow_twins/projector.py36
2 files changed, 37 insertions, 0 deletions
diff --git a/text_recognizer/networks/barlow_twins/__init__.py b/text_recognizer/networks/barlow_twins/__init__.py
new file mode 100644
index 0000000..0b74818
--- /dev/null
+++ b/text_recognizer/networks/barlow_twins/__init__.py
@@ -0,0 +1 @@
+"""Module for projector network in Barlow Twins."""
diff --git a/text_recognizer/networks/barlow_twins/projector.py b/text_recognizer/networks/barlow_twins/projector.py
new file mode 100644
index 0000000..05d5e2e
--- /dev/null
+++ b/text_recognizer/networks/barlow_twins/projector.py
@@ -0,0 +1,36 @@
+"""Projector network in Barlow Twins."""
+
+from typing import List
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class Projector(nn.Module):
+ """MLP network."""
+
+ def __init__(self, dims: List[int]) -> None:
+ super().__init__()
+ self.dims = dims
+ self.network = self._build()
+
+ def _build(self) -> nn.Sequential:
+ """Builds projector network."""
+ layers = [
+ nn.Sequential(
+ nn.Linear(
+ in_features=self.dims[i], out_features=self.dims[i + 1], bias=False
+ ),
+ nn.BatchNorm1d(self.dims[i + 1]),
+ nn.ReLU(inplace=True),
+ )
+ for i in range(len(self.dims) - 2)
+ ]
+ layers.append(
+ nn.Linear(in_features=self.dims[-2], out_features=self.dims[-1], bias=False)
+ )
+ return nn.Sequential(*layers)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Project latent to higher dimesion."""
+ return self.network(x)