diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-10 18:05:24 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-10 18:05:24 +0200 |
commit | 38f546f0b86fc0dc89863b00c5ee8c6685394ef2 (patch) | |
tree | deaf5501c8e687e235add858ad92e068df7d3615 /text_recognizer/data/transforms/barlow.py | |
parent | 8291a87c64f9a5f18caec82201bea15579b49730 (diff) |
Add custom transforms
Diffstat (limited to 'text_recognizer/data/transforms/barlow.py')
-rw-r--r-- | text_recognizer/data/transforms/barlow.py | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/text_recognizer/data/transforms/barlow.py b/text_recognizer/data/transforms/barlow.py new file mode 100644 index 0000000..78683cb --- /dev/null +++ b/text_recognizer/data/transforms/barlow.py @@ -0,0 +1,19 @@ +"""Augmentations for training Barlow Twins.""" +from omegaconf.dictconfig import DictConfig +from torch import Tensor + +from text_recognizer.data.transforms.load_transform import load_transform + + +class BarlowTransform: + """Applies two different transforms to input data.""" + + def __init__(self, prim: DictConfig, bis: DictConfig) -> None: + self.prim = load_transform(prim) + self.bis = load_transform(bis) + + def __call__(self, data: Tensor) -> Tensor: + """Applies two different augmentation on the input.""" + x_prim = self.prim(data) + x_bis = self.bis(data) + return x_prim, x_bis |