summaryrefslogtreecommitdiff
path: root/text_recognizer
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-11 22:13:54 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-11 22:13:54 +0200
commitca1925433861f6b1037bcd81112d56717d9f153b (patch)
treefdd43a527957c26f9c9425d574cb22fbc6e1f10a /text_recognizer
parent9e0496ace20d5b7e3cde0dcfc1e8400039e51916 (diff)
Add BarlowNetwork
Diffstat (limited to 'text_recognizer')
-rw-r--r--text_recognizer/networks/barlow_twins/network.py19
1 files changed, 19 insertions, 0 deletions
diff --git a/text_recognizer/networks/barlow_twins/network.py b/text_recognizer/networks/barlow_twins/network.py
new file mode 100644
index 0000000..874e570
--- /dev/null
+++ b/text_recognizer/networks/barlow_twins/network.py
@@ -0,0 +1,19 @@
+"""Barlow Twins network."""
+from typing import Type
+
+from torch import nn, Tensor
+import torch.nn.functional as F
+
+
+class BarlowTwins(nn.Module):
+ def __init__(self, encoder: Type[nn.Module], projector: Type[nn.Module]) -> None:
+ super().__init__()
+ self.encoder = encoder
+ self.projector = projector
+
+ def forward(self, x: Tensor) -> Tensor:
+ z = self.encoder(x)
+ z_e = F.adaptive_avg_pool2d(z, (1, 1)).flatten(start_dim=1)
+ z_p = self.projector(z_e)
+ return z_p
+