summaryrefslogtreecommitdiff
path: root/text_recognizer/data/transforms/barlow.py
blob: 78683cb1617b317269697aaa6ed2edf03dd5eba3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
"""Augmentations for training Barlow Twins."""
from typing import Tuple

from omegaconf.dictconfig import DictConfig
from torch import Tensor

from text_recognizer.data.transforms.load_transform import load_transform


class BarlowTransform:
    """Applies two different transforms to the same input data.

    Each call returns two independently transformed views of one sample, as
    used for Barlow Twins training.
    """

    def __init__(self, prim: DictConfig, bis: DictConfig) -> None:
        """Build both transform pipelines from their configs.

        Args:
            prim: Config for the first transform; passed to ``load_transform``.
            bis: Config for the second transform; passed to ``load_transform``.
        """
        self.prim = load_transform(prim)
        self.bis = load_transform(bis)

    def __call__(self, data: Tensor) -> Tuple[Tensor, Tensor]:
        """Apply both transforms to the input and return the two results.

        Args:
            data: Input sample. The same object is passed to both transforms.
                NOTE(review): if either transform mutates its input in place,
                the second view would see modified data. Confirm that the
                configured transforms are non-mutating.

        Returns:
            A tuple ``(x_prim, x_bis)``: the output of ``prim`` and the output
            of ``bis`` applied to ``data``.
        """
        x_prim = self.prim(data)
        x_bis = self.bis(data)
        return x_prim, x_bis