summaryrefslogtreecommitdiff
path: root/text_recognizer
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer')
-rw-r--r--text_recognizer/networks/barlow_twins/network.py19
1 files changed, 19 insertions, 0 deletions
diff --git a/text_recognizer/networks/barlow_twins/network.py b/text_recognizer/networks/barlow_twins/network.py
new file mode 100644
index 0000000..874e570
--- /dev/null
+++ b/text_recognizer/networks/barlow_twins/network.py
@@ -0,0 +1,19 @@
+"""Barlow Twins network."""
+from typing import Type
+
+from torch import nn, Tensor
+import torch.nn.functional as F
+
+
+class BarlowTwins(nn.Module):
+ def __init__(self, encoder: Type[nn.Module], projector: Type[nn.Module]) -> None:
+ super().__init__()
+ self.encoder = encoder
+ self.projector = projector
+
+ def forward(self, x: Tensor) -> Tensor:
+ z = self.encoder(x)
+ z_e = F.adaptive_avg_pool2d(z, (1, 1)).flatten(start_dim=1)
+ z_p = self.projector(z_e)
+ return z_p
+