summaryrefslogtreecommitdiff
path: root/text_recognizer/data/transforms/barlow.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-10 18:05:24 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-10 18:05:24 +0200
commit38f546f0b86fc0dc89863b00c5ee8c6685394ef2 (patch)
treedeaf5501c8e687e235add858ad92e068df7d3615 /text_recognizer/data/transforms/barlow.py
parent8291a87c64f9a5f18caec82201bea15579b49730 (diff)
Add custom transforms
Diffstat (limited to 'text_recognizer/data/transforms/barlow.py')
-rw-r--r--text_recognizer/data/transforms/barlow.py19
1 files changed, 19 insertions, 0 deletions
diff --git a/text_recognizer/data/transforms/barlow.py b/text_recognizer/data/transforms/barlow.py
new file mode 100644
index 0000000..78683cb
--- /dev/null
+++ b/text_recognizer/data/transforms/barlow.py
@@ -0,0 +1,19 @@
+"""Augmentations for training Barlow Twins."""
+from omegaconf.dictconfig import DictConfig
+from torch import Tensor
+
+from text_recognizer.data.transforms.load_transform import load_transform
+
+
+class BarlowTransform:
+ """Applies two different transforms to input data."""
+
+ def __init__(self, prim: DictConfig, bis: DictConfig) -> None:
+ self.prim = load_transform(prim)
+ self.bis = load_transform(bis)
+
+ def __call__(self, data: Tensor) -> Tensor:
+ """Applies two different augmentation on the input."""
+ x_prim = self.prim(data)
+ x_bis = self.bis(data)
+ return x_prim, x_bis