summaryrefslogtreecommitdiff
path: root/text_recognizer/data
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data')
-rw-r--r--text_recognizer/data/base_dataset.py16
-rw-r--r--text_recognizer/data/emnist.py2
-rw-r--r--text_recognizer/data/emnist_lines.py8
-rw-r--r--text_recognizer/data/iam.py2
-rw-r--r--text_recognizer/data/iam_extended_paragraphs.py15
-rw-r--r--text_recognizer/data/iam_lines.py7
-rw-r--r--text_recognizer/data/iam_paragraphs.py16
-rw-r--r--text_recognizer/data/iam_synthetic_paragraphs.py56
-rw-r--r--text_recognizer/data/stems/__init__.py0
-rw-r--r--text_recognizer/data/stems/image.py18
-rw-r--r--text_recognizer/data/stems/line.py86
-rw-r--r--text_recognizer/data/stems/paragraph.py66
12 files changed, 235 insertions, 57 deletions
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py
index 4ceb818..b840bc8 100644
--- a/text_recognizer/data/base_dataset.py
+++ b/text_recognizer/data/base_dataset.py
@@ -5,8 +5,6 @@ import torch
from torch import Tensor
from torch.utils.data import Dataset
-from text_recognizer.data.transforms.load_transform import load_transform_from_file
-
class BaseDataset(Dataset):
r"""Base Dataset class that processes data and targets through optional transfroms.
@@ -23,8 +21,8 @@ class BaseDataset(Dataset):
self,
data: Union[Sequence, Tensor],
targets: Union[Sequence, Tensor],
- transform: Union[Optional[Callable], str],
- target_transform: Union[Optional[Callable], str],
+ transform: Callable,
+ target_transform: Callable,
) -> None:
super().__init__()
@@ -34,16 +32,6 @@ class BaseDataset(Dataset):
self.target_transform = target_transform
if len(self.data) != len(self.targets):
raise ValueError("Data and targets must be of equal length.")
- self.transform = self._load_transform(self.transform)
- self.target_transform = self._load_transform(self.target_transform)
-
- @staticmethod
- def _load_transform(
- transform: Union[Optional[Callable], str]
- ) -> Optional[Callable]:
- if isinstance(transform, str):
- return load_transform_from_file(transform)
- return transform
def __len__(self) -> int:
"""Return the length of the dataset."""
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index 9c5727f..143705e 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -15,7 +15,7 @@ from text_recognizer.data.base_data_module import BaseDataModule, load_and_print
from text_recognizer.data.base_dataset import BaseDataset, split_dataset
from text_recognizer.data.transforms.load_transform import load_transform_from_file
from text_recognizer.data.utils.download_utils import download_dataset
-from text_recognizer.metadata import emnist as metadata
+import text_recognizer.metadata.emnist as metadata
class EMNIST(BaseDataModule):
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 63c9f22..88aac0d 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -13,9 +13,9 @@ from text_recognizer.data.base_data_module import BaseDataModule, load_and_print
from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels
from text_recognizer.data.emnist import EMNIST
from text_recognizer.data.mappings import EmnistMapping
-from text_recognizer.data.transforms.load_transform import load_transform_from_file
+from text_recognizer.data.stems.line import LineStem
from text_recognizer.data.utils.sentence_generator import SentenceGenerator
-from text_recognizer.metadata import emnist_lines as metadata
+import text_recognizer.metadata.emnist_lines as metadata
class EMNISTLines(BaseDataModule):
@@ -250,6 +250,6 @@ def _create_dataset_of_images(
def generate_emnist_lines() -> None:
"""Generates a synthetic handwritten dataset and displays info."""
- transform = load_transform_from_file("transform/emnist_lines.yaml")
- test_transform = load_transform_from_file("test_transform/default.yaml")
+ transform = LineStem(augment=False)
+ test_transform = LineStem(augment=False)
load_and_print_info(EMNISTLines(transform=transform, test_transform=test_transform))
diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py
index c20b50b..2ce1e9c 100644
--- a/text_recognizer/data/iam.py
+++ b/text_recognizer/data/iam.py
@@ -15,7 +15,7 @@ from loguru import logger as log
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.utils.download_utils import download_dataset
-from text_recognizer.metadata import iam as metadata
+import text_recognizer.metadata.iam as metadata
class IAM(BaseDataModule):
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 3ec8221..658626c 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -6,8 +6,10 @@ from torch.utils.data import ConcatDataset
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.iam_paragraphs import IAMParagraphs
from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs
+from text_recognizer.data.transforms.pad import Pad
from text_recognizer.data.mappings import EmnistMapping
-from text_recognizer.data.transforms.load_transform import load_transform_from_file
+from text_recognizer.data.stems.paragraph import ParagraphStem
+import text_recognizer.metadata.iam_paragraphs as metadata
class IAMExtendedParagraphs(BaseDataModule):
@@ -104,8 +106,13 @@ class IAMExtendedParagraphs(BaseDataModule):
def show_dataset_info() -> None:
"""Displays Iam extended dataset information."""
- transform = load_transform_from_file("transform/paragraphs.yaml")
- test_transform = load_transform_from_file("test_transform/paragraphs_test.yaml")
+ transform = ParagraphStem(augment=False)
+ test_transform = ParagraphStem(augment=False)
+ target_transform = Pad(metadata.MAX_LABEL_LENGTH, 3)
load_and_print_info(
- IAMExtendedParagraphs(transform=transform, test_transform=test_transform)
+ IAMExtendedParagraphs(
+ transform=transform,
+ test_transform=test_transform,
+ target_transform=target_transform,
+ )
)
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index 3bb189c..e60d1ba 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -21,8 +21,9 @@ from text_recognizer.data.base_dataset import (
from text_recognizer.data.iam import IAM
from text_recognizer.data.mappings import EmnistMapping
from text_recognizer.data.transforms.load_transform import load_transform_from_file
+from text_recognizer.data.stems.line import IamLinesStem
from text_recognizer.data.utils import image_utils
-from text_recognizer.metadata import iam_lines as metadata
+import text_recognizer.metadata.iam_lines as metadata
ImageFile.LOAD_TRUNCATED_IMAGES = True
@@ -227,6 +228,6 @@ def load_line_crops_and_labels(split: str, data_dirname: Path) -> Tuple[List, Li
def generate_iam_lines() -> None:
"""Displays Iam Lines dataset statistics."""
- transform = load_transform_from_file("transform/lines.yaml")
- test_transform = load_transform_from_file("test_transform/lines.yaml")
+ transform = IamLinesStem()
+ test_transform = IamLinesStem()
load_and_print_info(IAMLines(transform=transform, test_transform=test_transform))
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index eec1b1f..fe1f15c 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -16,9 +16,10 @@ from text_recognizer.data.base_dataset import (
split_dataset,
)
from text_recognizer.data.iam import IAM
+from text_recognizer.data.transforms.pad import Pad
from text_recognizer.data.mappings import EmnistMapping
-from text_recognizer.data.transforms.load_transform import load_transform_from_file
-from text_recognizer.metadata import iam_paragraphs as metadata
+from text_recognizer.data.stems.paragraph import ParagraphStem
+import text_recognizer.metadata.iam_paragraphs as metadata
class IAMParagraphs(BaseDataModule):
@@ -294,8 +295,13 @@ def _num_lines(label: str) -> int:
def create_iam_paragraphs() -> None:
"""Loads and displays dataset statistics."""
- transform = load_transform_from_file("transform/paragraphs.yaml")
- test_transform = load_transform_from_file("test_transform/paragraphs_test.yaml")
+ transform = ParagraphStem()
+ test_transform = ParagraphStem()
+ target_transform = Pad(metadata.MAX_LABEL_LENGTH, 3)
load_and_print_info(
- IAMParagraphs(transform=transform, test_transform=test_transform)
+ IAMParagraphs(
+ transform=transform,
+ test_transform=test_transform,
+ target_transform=target_transform,
+ )
)
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index 52ed398..91fda4a 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -9,25 +9,20 @@ from PIL import Image
from text_recognizer.data.base_data_module import load_and_print_info
from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels
from text_recognizer.data.iam import IAM
-from text_recognizer.data.iam_lines import (
- line_crops_and_labels,
- load_line_crops_and_labels,
- save_images_and_labels,
-)
from text_recognizer.data.iam_paragraphs import (
- IMAGE_SCALE_FACTOR,
- NEW_LINE_TOKEN,
IAMParagraphs,
get_dataset_properties,
resize_image,
)
-from text_recognizer.data.mappings import EmnistMapping
-from text_recognizer.data.transforms.load_transform import load_transform_from_file
-from text_recognizer.metadata import shared as metadata
-
-PROCESSED_DATA_DIRNAME = (
- metadata.DATA_DIRNAME / "processed" / "iam_synthetic_paragraphs"
+from text_recognizer.data.iam_lines import (
+ line_crops_and_labels,
+ load_line_crops_and_labels,
+ save_images_and_labels,
)
+from text_recognizer.data.mappings import EmnistMapping
+from text_recognizer.data.stems.paragraph import ParagraphStem
+from text_recognizer.data.transforms.pad import Pad
+import text_recognizer.metadata.iam_synthetic_paragraphs as metadata
class IAMSyntheticParagraphs(IAMParagraphs):
@@ -57,26 +52,32 @@ class IAMSyntheticParagraphs(IAMParagraphs):
def prepare_data(self) -> None:
"""Prepare IAM lines to be used to generate paragraphs."""
- if PROCESSED_DATA_DIRNAME.exists():
+ if metadata.PROCESSED_DATA_DIRNAME.exists():
return
log.info("Preparing IAM lines for synthetic paragraphs dataset.")
log.info("Cropping IAM line regions and loading labels.")
- iam = IAM(mapping=EmnistMapping(extra_symbols=(NEW_LINE_TOKEN,)))
+ iam = IAM(mapping=EmnistMapping(extra_symbols=(metadata.NEW_LINE_TOKEN,)))
iam.prepare_data()
crops_train, labels_train = line_crops_and_labels(iam, "train")
crops_test, labels_test = line_crops_and_labels(iam, "test")
- crops_train = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops_train]
- crops_test = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops_test]
+ crops_train = [
+ resize_image(crop, metadata.IMAGE_SCALE_FACTOR) for crop in crops_train
+ ]
+ crops_test = [
+ resize_image(crop, metadata.IMAGE_SCALE_FACTOR) for crop in crops_test
+ ]
- log.info(f"Saving images and labels at {PROCESSED_DATA_DIRNAME}")
+ log.info(f"Saving images and labels at {metadata.PROCESSED_DATA_DIRNAME}")
save_images_and_labels(
- crops_train, labels_train, "train", PROCESSED_DATA_DIRNAME
+ crops_train, labels_train, "train", metadata.PROCESSED_DATA_DIRNAME
+ )
+ save_images_and_labels(
+ crops_test, labels_test, "test", metadata.PROCESSED_DATA_DIRNAME
)
- save_images_and_labels(crops_test, labels_test, "test", PROCESSED_DATA_DIRNAME)
def setup(self, stage: str = None) -> None:
"""Loading synthetic dataset."""
@@ -85,7 +86,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
if stage == "fit" or stage is None:
line_crops, line_labels = load_line_crops_and_labels(
- "train", PROCESSED_DATA_DIRNAME
+ "train", metadata.PROCESSED_DATA_DIRNAME
)
data, paragraphs_labels = generate_synthetic_paragraphs(
line_crops=line_crops, line_labels=line_labels
@@ -157,7 +158,7 @@ def generate_synthetic_paragraphs(
paragraphs_crops, paragraphs_labels = [], []
for paragraph_indices in batched_indices_list:
- paragraph_label = NEW_LINE_TOKEN.join(
+ paragraph_label = metadata.NEW_LINE_TOKEN.join(
[line_labels[i] for i in paragraph_indices]
)
if len(paragraph_label) > paragraphs_properties["label_length"]["max"]:
@@ -236,8 +237,13 @@ def generate_random_batches(
def create_synthetic_iam_paragraphs() -> None:
"""Creates and prints IAM Synthetic Paragraphs dataset."""
- transform = load_transform_from_file("transform/paragraphs.yaml")
- test_transform = load_transform_from_file("test_transform/paragraphs.yaml")
+ transform = ParagraphStem()
+ test_transform = ParagraphStem()
+ target_transform = Pad(metadata.MAX_LABEL_LENGTH, 3)
load_and_print_info(
- IAMSyntheticParagraphs(transform=transform, test_transform=test_transform)
+ IAMSyntheticParagraphs(
+ transform=transform,
+ test_transform=test_transform,
+ target_transform=target_transform,
+ )
)
diff --git a/text_recognizer/data/stems/__init__.py b/text_recognizer/data/stems/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/text_recognizer/data/stems/__init__.py
diff --git a/text_recognizer/data/stems/image.py b/text_recognizer/data/stems/image.py
new file mode 100644
index 0000000..f04b3a0
--- /dev/null
+++ b/text_recognizer/data/stems/image.py
@@ -0,0 +1,18 @@
+from PIL import Image
+import torch
+from torch import Tensor
+import torchvision.transforms as T
+
+
+class ImageStem:
+    """Base stem that turns a PIL image into a tensor.
+
+    An image passes through three stages in order: ``pil_transform``,
+    ``pil_to_tensor`` and ``torch_transform``. The first and last stages are
+    created empty here, so they can be replaced to add further processing.
+    """
+
+    def __init__(self) -> None:
+        """Create the three stages, with empty PIL and tensor transforms."""
+        self.pil_transform = T.Compose([])
+        self.pil_to_tensor = T.ToTensor()
+        self.torch_transform = torch.nn.Sequential()
+
+    def __call__(self, img: Image.Image) -> Tensor:
+        """Run the PIL stage, convert to a tensor, then run the tensor stage.
+
+        Args:
+            img: The PIL image to process.
+
+        Returns:
+            The output of ``torch_transform`` applied to the converted image.
+        """
+        img = self.pil_transform(img)
+        img = self.pil_to_tensor(img)
+        # Run the tensor stage without recording gradients.
+        with torch.no_grad():
+            img = self.torch_transform(img)
+        return img
diff --git a/text_recognizer/data/stems/line.py b/text_recognizer/data/stems/line.py
new file mode 100644
index 0000000..2fe1a2c
--- /dev/null
+++ b/text_recognizer/data/stems/line.py
@@ -0,0 +1,86 @@
+import random
+
+from PIL import Image
+import torchvision.transforms as T
+
+import text_recognizer.metadata.iam_lines as metadata
+from text_recognizer.data.stems.image import ImageStem
+
+
+class LineStem(ImageStem):
+    """A stem for handling images containing a line of text."""
+
+    def __init__(
+        self, augment=False, color_jitter_kwargs=None, random_affine_kwargs=None
+    ):
+        """Configure the PIL stage, adding augmentation when requested.
+
+        Args:
+            augment: If True, apply color jitter followed by a random affine
+                transform to each PIL image. If False, the PIL stage set up
+                by ``ImageStem`` is left unchanged.
+            color_jitter_kwargs: Keyword arguments for ``T.ColorJitter``.
+                Defaults to a brightness range of (0.5, 1).
+            random_affine_kwargs: Keyword arguments for ``T.RandomAffine``.
+                Defaults to the dictionary built below.
+        """
+        super().__init__()
+        if color_jitter_kwargs is None:
+            color_jitter_kwargs = {"brightness": (0.5, 1)}
+        if random_affine_kwargs is None:
+            random_affine_kwargs = {
+                "degrees": 3,
+                "translate": (0, 0.05),
+                "scale": (0.4, 1.1),
+                "shear": (-40, 50),
+                "interpolation": T.InterpolationMode.BILINEAR,
+                "fill": 0,
+            }
+
+        if augment:
+            # The attribute must be `pil_transform` (singular): that is the
+            # name ImageStem.__call__ applies. Storing the pipeline under any
+            # other name leaves the empty base transform in place, so the
+            # augmentation would silently never run.
+            self.pil_transform = T.Compose(
+                [
+                    T.ColorJitter(**color_jitter_kwargs),
+                    T.RandomAffine(**random_affine_kwargs),
+                ]
+            )
+
+
+class IamLinesStem(ImageStem):
+ """A stem for handling images containing lines of text from the IAMLines dataset."""
+
+ def __init__(
+ self, augment=False, color_jitter_kwargs=None, random_affine_kwargs=None
+ ):
+ """Build the PIL stage: embed the crop in a fixed-size image, then optionally augment.
+
+ Args:
+ augment: If True, the crop width is randomly stretched and color jitter
+ plus a random affine transform are appended to the PIL stage.
+ color_jitter_kwargs: Keyword arguments for T.ColorJitter; defaults to a
+ brightness range of (0.8, 1.6).
+ random_affine_kwargs: Keyword arguments for T.RandomAffine; defaults to
+ the dictionary built below.
+ """
+ super().__init__()
+
+ # The default argument captures the constructor's `augment` flag when the
+ # helper is defined, so T.Lambda can call it with the crop alone.
+ def embed_crop(crop, augment=augment):
+ # crop is PIL.image of dtype="L" (so values range from 0 -> 255)
+ # Blank "L"-mode canvas of the metadata image size that the crop is pasted onto.
+ image = Image.new("L", (metadata.IMAGE_WIDTH, metadata.IMAGE_HEIGHT))
+
+ # Resize crop
+ # The crop is scaled to the full canvas height; the width follows the
+ # crop's aspect ratio.
+ crop_width, crop_height = crop.size
+ new_crop_height = metadata.IMAGE_HEIGHT
+ new_crop_width = int(new_crop_height * (crop_width / crop_height))
+ if augment:
+ # Add random stretching
+ new_crop_width = int(new_crop_width * random.uniform(0.9, 1.1))
+ # NOTE(review): indentation is collapsed in this view, so it cannot be seen
+ # whether the clamp below runs only when augmenting or always; confirm
+ # against the checked-in file.
+ new_crop_width = min(new_crop_width, metadata.IMAGE_WIDTH)
+ crop_resized = crop.resize(
+ (new_crop_width, new_crop_height), resample=Image.BILINEAR
+ )
+
+ # Embed in the image
+ # x is at most CHAR_WIDTH, and smaller when the resized crop leaves less room.
+ x = min(metadata.CHAR_WIDTH, metadata.IMAGE_WIDTH - new_crop_width)
+ # new_crop_height equals IMAGE_HEIGHT, so y evaluates to 0.
+ y = metadata.IMAGE_HEIGHT - new_crop_height
+
+ image.paste(crop_resized, (x, y))
+
+ return image
+
+ if color_jitter_kwargs is None:
+ color_jitter_kwargs = {"brightness": (0.8, 1.6)}
+ if random_affine_kwargs is None:
+ random_affine_kwargs = {
+ "degrees": 1,
+ "shear": (-30, 20),
+ "interpolation": T.InterpolationMode.BILINEAR,
+ "fill": 0,
+ }
+
+ # embed_crop always runs first; the augmentations are appended only on request.
+ pil_transform_list = [T.Lambda(embed_crop)]
+ if augment:
+ pil_transform_list += [
+ T.ColorJitter(**color_jitter_kwargs),
+ T.RandomAffine(**random_affine_kwargs),
+ ]
+ self.pil_transform = T.Compose(pil_transform_list)
diff --git a/text_recognizer/data/stems/paragraph.py b/text_recognizer/data/stems/paragraph.py
new file mode 100644
index 0000000..39e1e59
--- /dev/null
+++ b/text_recognizer/data/stems/paragraph.py
@@ -0,0 +1,66 @@
+"""Iam paragraph stem class."""
+import torchvision.transforms as T
+
+import text_recognizer.metadata.iam_paragraphs as metadata
+from text_recognizer.data.stems.image import ImageStem
+
+
+# Module-level aliases for values defined in the iam_paragraphs metadata module.
+IMAGE_HEIGHT, IMAGE_WIDTH = metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH
+# Size passed to the crop transforms built by ParagraphStem.
+IMAGE_SHAPE = metadata.IMAGE_SHAPE
+
+MAX_LABEL_LENGTH = metadata.MAX_LABEL_LENGTH
+
+
+class ParagraphStem(ImageStem):
+    """Stem for paragraph images: a center crop, or a random crop with augmentation."""
+
+    def __init__(
+        self,
+        augment=False,
+        color_jitter_kwargs=None,
+        random_affine_kwargs=None,
+        random_perspective_kwargs=None,
+        gaussian_blur_kwargs=None,
+        sharpness_kwargs=None,
+    ):
+        """Build the PIL transform pipeline.
+
+        Args:
+            augment: When False, images are only center-cropped to
+                IMAGE_SHAPE. When True, the augmentation pipeline is used.
+            color_jitter_kwargs: Keyword arguments for ``T.ColorJitter``.
+            random_affine_kwargs: Keyword arguments for ``T.RandomAffine``.
+            random_perspective_kwargs: Keyword arguments for
+                ``T.RandomPerspective``.
+            gaussian_blur_kwargs: Keyword arguments for ``T.GaussianBlur``.
+            sharpness_kwargs: Keyword arguments for
+                ``T.RandomAdjustSharpness``.
+
+        Any ``*_kwargs`` argument left as None falls back to the built-in
+        setting listed in the body; they are only consulted when augmenting.
+        """
+        super().__init__()
+
+        # No augmentation requested: crop and finish.
+        if not augment:
+            self.pil_transform = T.Compose([T.CenterCrop(IMAGE_SHAPE)])
+            return
+
+        # Resolve every option the caller left unset to its built-in setting.
+        color_jitter_kwargs = (
+            {"brightness": 0.4, "contrast": 0.4}
+            if color_jitter_kwargs is None
+            else color_jitter_kwargs
+        )
+        random_affine_kwargs = (
+            {
+                "degrees": 3,
+                "shear": 6,
+                "scale": (0.95, 1),
+                "interpolation": T.InterpolationMode.BILINEAR,
+            }
+            if random_affine_kwargs is None
+            else random_affine_kwargs
+        )
+        random_perspective_kwargs = (
+            {
+                "distortion_scale": 0.2,
+                "p": 0.5,
+                "interpolation": T.InterpolationMode.BILINEAR,
+            }
+            if random_perspective_kwargs is None
+            else random_perspective_kwargs
+        )
+        gaussian_blur_kwargs = (
+            {"kernel_size": (3, 3), "sigma": (0.1, 1.0)}
+            if gaussian_blur_kwargs is None
+            else gaussian_blur_kwargs
+        )
+        sharpness_kwargs = (
+            {"sharpness_factor": 2, "p": 0.5}
+            if sharpness_kwargs is None
+            else sharpness_kwargs
+        )
+
+        # Transforms are listed in the order they are applied to the image.
+        augmentations = [
+            T.ColorJitter(**color_jitter_kwargs),
+            T.RandomCrop(
+                size=IMAGE_SHAPE,
+                padding=None,
+                pad_if_needed=True,
+                fill=0,
+                padding_mode="constant",
+            ),
+            T.RandomAffine(**random_affine_kwargs),
+            T.RandomPerspective(**random_perspective_kwargs),
+            T.GaussianBlur(**gaussian_blur_kwargs),
+            T.RandomAdjustSharpness(**sharpness_kwargs),
+        ]
+        self.pil_transform = T.Compose(augmentations)