summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/barlow_twins/projector.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/barlow_twins/projector.py')
-rw-r--r--text_recognizer/networks/barlow_twins/projector.py36
1 files changed, 36 insertions, 0 deletions
diff --git a/text_recognizer/networks/barlow_twins/projector.py b/text_recognizer/networks/barlow_twins/projector.py
new file mode 100644
index 0000000..05d5e2e
--- /dev/null
+++ b/text_recognizer/networks/barlow_twins/projector.py
@@ -0,0 +1,36 @@
+"""Projector network in Barlow Twins."""
+
+from typing import List
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class Projector(nn.Module):
+ """MLP network."""
+
+ def __init__(self, dims: List[int]) -> None:
+ super().__init__()
+ self.dims = dims
+ self.network = self._build()
+
+ def _build(self) -> nn.Sequential:
+ """Builds projector network."""
+ layers = [
+ nn.Sequential(
+ nn.Linear(
+ in_features=self.dims[i], out_features=self.dims[i + 1], bias=False
+ ),
+ nn.BatchNorm1d(self.dims[i + 1]),
+ nn.ReLU(inplace=True),
+ )
+ for i in range(len(self.dims) - 2)
+ ]
+ layers.append(
+ nn.Linear(in_features=self.dims[-2], out_features=self.dims[-1], bias=False)
+ )
+ return nn.Sequential(*layers)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Project latent to higher dimesion."""
+ return self.network(x)