summaryrefslogtreecommitdiff
path: root/text_recognizer/data/transforms/barlow.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/transforms/barlow.py')
-rw-r--r--text_recognizer/data/transforms/barlow.py19
1 files changed, 19 insertions, 0 deletions
diff --git a/text_recognizer/data/transforms/barlow.py b/text_recognizer/data/transforms/barlow.py
new file mode 100644
index 0000000..78683cb
--- /dev/null
+++ b/text_recognizer/data/transforms/barlow.py
@@ -0,0 +1,19 @@
+"""Augmentations for training Barlow Twins."""
+from omegaconf.dictconfig import DictConfig
+from torch import Tensor
+
+from text_recognizer.data.transforms.load_transform import load_transform
+
+
+class BarlowTransform:
+ """Applies two different transforms to input data."""
+
+ def __init__(self, prim: DictConfig, bis: DictConfig) -> None:
+ self.prim = load_transform(prim)
+ self.bis = load_transform(bis)
+
+ def __call__(self, data: Tensor) -> Tensor:
+ """Applies two different augmentation on the input."""
+ x_prim = self.prim(data)
+ x_bis = self.bis(data)
+ return x_prim, x_bis