diff options
Diffstat (limited to 'text_recognizer/data/stems/paragraph.py')
-rw-r--r-- | text_recognizer/data/stems/paragraph.py | 66 |
1 files changed, 66 insertions, 0 deletions
diff --git a/text_recognizer/data/stems/paragraph.py b/text_recognizer/data/stems/paragraph.py new file mode 100644 index 0000000..39e1e59 --- /dev/null +++ b/text_recognizer/data/stems/paragraph.py @@ -0,0 +1,66 @@ +"""Iam paragraph stem class.""" +import torchvision.transforms as T + +import text_recognizer.metadata.iam_paragraphs as metadata +from text_recognizer.data.stems.image import ImageStem + + +IMAGE_HEIGHT, IMAGE_WIDTH = metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH +IMAGE_SHAPE = metadata.IMAGE_SHAPE + +MAX_LABEL_LENGTH = metadata.MAX_LABEL_LENGTH + + +class ParagraphStem(ImageStem): + """A stem for handling images that contain a paragraph of text.""" + + def __init__( + self, + augment=False, + color_jitter_kwargs=None, + random_affine_kwargs=None, + random_perspective_kwargs=None, + gaussian_blur_kwargs=None, + sharpness_kwargs=None, + ): + super().__init__() + + if not augment: + self.pil_transform = T.Compose([T.CenterCrop(IMAGE_SHAPE)]) + else: + if color_jitter_kwargs is None: + color_jitter_kwargs = {"brightness": 0.4, "contrast": 0.4} + if random_affine_kwargs is None: + random_affine_kwargs = { + "degrees": 3, + "shear": 6, + "scale": (0.95, 1), + "interpolation": T.InterpolationMode.BILINEAR, + } + if random_perspective_kwargs is None: + random_perspective_kwargs = { + "distortion_scale": 0.2, + "p": 0.5, + "interpolation": T.InterpolationMode.BILINEAR, + } + if gaussian_blur_kwargs is None: + gaussian_blur_kwargs = {"kernel_size": (3, 3), "sigma": (0.1, 1.0)} + if sharpness_kwargs is None: + sharpness_kwargs = {"sharpness_factor": 2, "p": 0.5} + + self.pil_transform = T.Compose( + [ + T.ColorJitter(**color_jitter_kwargs), + T.RandomCrop( + size=IMAGE_SHAPE, + padding=None, + pad_if_needed=True, + fill=0, + padding_mode="constant", + ), + T.RandomAffine(**random_affine_kwargs), + T.RandomPerspective(**random_perspective_kwargs), + T.GaussianBlur(**gaussian_blur_kwargs), + T.RandomAdjustSharpness(**sharpness_kwargs), + ] + ) |