author     Gustaf Rydholm <gustaf.rydholm@gmail.com>   2021-07-06 17:42:53 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>   2021-07-06 17:42:53 +0200
commit     eb5b206f7e1b08435378d2a02395307be55ee6f1 (patch)
tree       0cd30234afab698eb632b20a7da97e3bc7e98882 /text_recognizer/data
parent     4d1f2cef39688871d2caafce42a09316381a27ae (diff)
Refactoring data with attrs and refactor conf for hydra
Diffstat (limited to 'text_recognizer/data')
-rw-r--r--  text_recognizer/data/base_data_module.py          29
-rw-r--r--  text_recognizer/data/emnist.py                     22
-rw-r--r--  text_recognizer/data/emnist_lines.py               35
-rw-r--r--  text_recognizer/data/iam.py                         6
-rw-r--r--  text_recognizer/data/iam_extended_paragraphs.py    33
-rw-r--r--  text_recognizer/data/iam_lines.py                  22
-rw-r--r--  text_recognizer/data/iam_paragraphs.py             32
7 files changed, 79 insertions(+), 100 deletions(-)
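
The diff replaces hand-written __init__ methods with attrs-declared fields on every DataModule. As orientation before reading the hunks, here is a minimal sketch of the pattern, assuming a reasonably recent attrs release; the ToyDataModule name and its field set are invented for illustration and are not code from this commit.

    import attr
    from pytorch_lightning import LightningDataModule
    from torch.utils.data import DataLoader, Dataset


    @attr.s(auto_attribs=True)
    class ToyDataModule(LightningDataModule):
        """Hypothetical DataModule using the attrs pattern adopted below."""

        batch_size: int = attr.ib(default=16)
        num_workers: int = attr.ib(default=0)
        data_train: Dataset = attr.ib(init=False, default=None)

        def __attrs_pre_init__(self) -> None:
            # attrs assigns the fields above only after this hook, so
            # LightningDataModule's own __init__ still runs first.
            super().__init__()

        def train_dataloader(self) -> DataLoader:
            # data_train is expected to be populated by setup().
            return DataLoader(
                self.data_train,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
            )
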
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index de5628f..18b1996 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -1,11 +1,13 @@
"""Base lightning DataModule class."""
from pathlib import Path
-from typing import Dict
+from typing import Any, Dict, Tuple
import attr
-import pytorch_lightning as LightningDataModule
+from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
+from text_recognizer.data.base_dataset import BaseDataset
+
def load_and_print_info(data_module_class: type) -> None:
"""Load dataset and print dataset information."""
@@ -19,17 +21,20 @@ def load_and_print_info(data_module_class: type) -> None:
class BaseDataModule(LightningDataModule):
"""Base PyTorch Lightning DataModule."""
- batch_size: int = attr.ib(default=16)
- num_workers: int = attr.ib(default=0)
-
def __attrs_pre_init__(self) -> None:
super().__init__()
- def __attrs_post_init__(self) -> None:
- # Placeholders for subclasses.
- self.dims = None
- self.output_dims = None
- self.mapping = None
+ batch_size: int = attr.ib(default=16)
+ num_workers: int = attr.ib(default=0)
+
+ # Placeholders
+ data_train: BaseDataset = attr.ib(init=False, default=None)
+ data_val: BaseDataset = attr.ib(init=False, default=None)
+ data_test: BaseDataset = attr.ib(init=False, default=None)
+ dims: Tuple[int, ...] = attr.ib(init=False, default=None)
+ output_dims: Tuple[int, ...] = attr.ib(init=False, default=None)
+ mapping: Any = attr.ib(init=False, default=None)
+ inverse_mapping: Dict[str, int] = attr.ib(init=False)
@classmethod
def data_dirname(cls) -> Path:
@@ -58,9 +63,7 @@ class BaseDataModule(LightningDataModule):
stage (Any): Variable to set splits.
"""
- self.data_train = None
- self.data_val = None
- self.data_test = None
+ pass
def train_dataloader(self) -> DataLoader:
"""Retun DataLoader for train data."""
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index 824b947..d51a42a 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -3,9 +3,10 @@ import json
import os
from pathlib import Path
import shutil
-from typing import Dict, List, Optional, Sequence, Tuple
+from typing import Callable, Dict, List, Optional, Sequence, Tuple
import zipfile
+import attr
import h5py
from loguru import logger
import numpy as np
@@ -32,6 +33,7 @@ PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5"
ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json"
+@attr.s(auto_attribs=True)
class EMNIST(BaseDataModule):
"""Lightning DataModule class for loading EMNIST dataset.
@@ -44,18 +46,12 @@ class EMNIST(BaseDataModule):
EMNIST ByClass: 814,255 characters. 62 unbalanced classes.
"""
- def __init__(
- self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8
- ) -> None:
- super().__init__(batch_size, num_workers)
- self.train_fraction = train_fraction
- self.mapping, self.inverse_mapping, self.input_shape = emnist_mapping()
- self.data_train = None
- self.data_val = None
- self.data_test = None
- self.transform = T.Compose([T.ToTensor()])
- self.dims = (1, *self.input_shape)
- self.output_dims = (1,)
+ train_fraction: float = attr.ib()
+ transform: Callable = attr.ib(init=False, default=T.Compose([T.ToTensor()]))
+
+ def __attrs_post_init__(self) -> None:
+ self.mapping, self.inverse_mapping, input_shape = emnist_mapping()
+ self.dims = (1, *input_shape)
def prepare_data(self) -> None:
"""Downloads dataset if not present."""
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 9650198..4747508 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -3,6 +3,7 @@ from collections import defaultdict
from pathlib import Path
from typing import Callable, Dict, Tuple
+import attr
import h5py
from loguru import logger
import numpy as np
@@ -31,31 +32,20 @@ IMAGE_X_PADDING = 28
MAX_OUTPUT_LENGTH = 89 # Same as IAMLines
+@attr.s(auto_attribs=True)
class EMNISTLines(BaseDataModule):
"""EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST,"""
- def __init__(
- self,
- augment: bool = True,
- batch_size: int = 128,
- num_workers: int = 0,
- max_length: int = 32,
- min_overlap: float = 0.0,
- max_overlap: float = 0.33,
- num_train: int = 10_000,
- num_val: int = 2_000,
- num_test: int = 2_000,
- ) -> None:
- super().__init__(batch_size, num_workers)
-
- self.augment = augment
- self.max_length = max_length
- self.min_overlap = min_overlap
- self.max_overlap = max_overlap
- self.num_train = num_train
- self.num_val = num_val
- self.num_test = num_test
+ augment: bool = attr.ib(default=True)
+ max_length: int = attr.ib(default=128)
+ min_overlap: float = attr.ib(default=0.0)
+ max_overlap: float = attr.ib(default=0.33)
+ num_train: int = attr.ib(default=10_000)
+ num_val: int = attr.ib(default=2_000)
+ num_test: int = attr.ib(default=2_000)
+ emnist: EMNIST = attr.ib(init=False, default=None)
+ def __attrs_post_init__(self) -> None:
self.emnist = EMNIST()
self.mapping = self.emnist.mapping
@@ -75,9 +65,6 @@ class EMNISTLines(BaseDataModule):
raise ValueError("max_length greater than MAX_OUTPUT_LENGTH")
self.output_dims = (MAX_OUTPUT_LENGTH, 1)
- self.data_train: BaseDataset = None
- self.data_val: BaseDataset = None
- self.data_test: BaseDataset = None
@property
def data_filename(self) -> Path:
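
EMNISTLines keeps its max_length sanity check in __attrs_post_init__. As an aside, attrs can express that kind of per-field check as a validator that runs inside the generated __init__; the sketch below is illustrative only, with an invented class name and a simplified condition, not a drop-in replacement for the committed check.

    import attr

    MAX_OUTPUT_LENGTH = 89  # same constant as in emnist_lines.py


    @attr.s(auto_attribs=True)
    class LinesConfig:
        max_length: int = attr.ib(default=32)

        @max_length.validator
        def _check_max_length(self, attribute: attr.Attribute, value: int) -> None:
            # Simplified stand-in for the check done in __attrs_post_init__.
            if value > MAX_OUTPUT_LENGTH:
                raise ValueError("max_length greater than MAX_OUTPUT_LENGTH")
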
diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py
index 261c8d3..3982c4f 100644
--- a/text_recognizer/data/iam.py
+++ b/text_recognizer/data/iam.py
@@ -5,6 +5,7 @@ from typing import Any, Dict, List
import xml.etree.ElementTree as ElementTree
import zipfile
+import attr
from boltons.cacheutils import cachedproperty
from loguru import logger
import toml
@@ -22,6 +23,7 @@ DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be.
LINE_REGION_PADDING = 16 # Add this many pixels around the exact coordinates.
+@attr.s(auto_attribs=True)
class IAM(BaseDataModule):
"""
"The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text,
@@ -35,9 +37,7 @@ class IAM(BaseDataModule):
The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only.
"""
- def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None:
- super().__init__(batch_size, num_workers)
- self.metadata = toml.load(METADATA_FILENAME)
+ metadata: Dict = attr.ib(init=False, default=toml.load(METADATA_FILENAME))
def prepare_data(self) -> None:
if self.xml_filenames:
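
One consequence of the attr.ib(init=False, default=toml.load(METADATA_FILENAME)) default on IAM.metadata is that it is evaluated once, when the class body runs at import time, and the resulting dict is shared by every IAM instance. If lazy, per-instance loading were preferred, attrs offers attr.Factory; a sketch under that assumption, with a hypothetical path standing in for the module's METADATA_FILENAME constant:

    import attr
    import toml

    METADATA_FILENAME = "path/to/iam_metadata.toml"  # hypothetical stand-in


    @attr.s(auto_attribs=True)
    class IAMMetadataSketch:
        # The factory runs at instantiation time instead of import time.
        metadata: dict = attr.ib(
            init=False, default=attr.Factory(lambda: toml.load(METADATA_FILENAME))
        )
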
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 0a30a42..886e37e 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -1,4 +1,7 @@
"""IAM original and sythetic dataset class."""
+from typing import Dict, List
+
+import attr
from torch.utils.data import ConcatDataset
from text_recognizer.data.base_dataset import BaseDataset
@@ -7,22 +10,26 @@ from text_recognizer.data.iam_paragraphs import IAMParagraphs
from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs
+@attr.s(auto_attribs=True)
class IAMExtendedParagraphs(BaseDataModule):
- def __init__(
- self,
- batch_size: int = 16,
- num_workers: int = 0,
- train_fraction: float = 0.8,
- augment: bool = True,
- word_pieces: bool = False,
- ) -> None:
- super().__init__(batch_size, num_workers)
+ train_fraction: float = attr.ib()
+ word_pieces: bool = attr.ib(default=False)
+
+ def __attrs_post_init__(self) -> None:
self.iam_paragraphs = IAMParagraphs(
- batch_size, num_workers, train_fraction, augment, word_pieces,
+ self.batch_size,
+ self.num_workers,
+ self.train_fraction,
+ self.augment,
+ self.word_pieces,
)
self.iam_synthetic_paragraphs = IAMSyntheticParagraphs(
- batch_size, num_workers, train_fraction, augment, word_pieces,
+ self.batch_size,
+ self.num_workers,
+ self.train_fraction,
+ self.augment,
+ self.word_pieces,
)
self.dims = self.iam_paragraphs.dims
@@ -30,10 +37,6 @@ class IAMExtendedParagraphs(BaseDataModule):
self.mapping = self.iam_paragraphs.mapping
self.inverse_mapping = self.iam_paragraphs.inverse_mapping
- self.data_train: BaseDataset = None
- self.data_val: BaseDataset = None
- self.data_test: BaseDataset = None
-
def prepare_data(self) -> None:
"""Prepares the paragraphs data."""
self.iam_paragraphs.prepare_data()
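
IAMExtendedParagraphs forwards its settings to the two child modules positionally. With attrs-generated constructors, the positional order is base-class fields first and then subclass fields in declaration order, so forwarding by keyword is the safer style if fields are ever reordered. A self-contained illustration with toy classes (not repo code):

    import attr


    @attr.s(auto_attribs=True)
    class Base:
        batch_size: int = attr.ib(default=16)
        num_workers: int = attr.ib(default=0)


    @attr.s(auto_attribs=True)
    class Child(Base):
        augment: bool = attr.ib(default=True)
        train_fraction: float = attr.ib(default=0.8)


    # Generated signature: Child(batch_size=16, num_workers=0, augment=True, train_fraction=0.8)
    positional = Child(16, 0, True, 0.8)
    keyword = Child(batch_size=16, num_workers=0, train_fraction=0.8, augment=True)
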
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index 9c78a22..e45e5c8 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -7,8 +7,9 @@ dataset.
import json
from pathlib import Path
import random
-from typing import List, Sequence, Tuple
+from typing import Dict, List, Sequence, Tuple
+import attr
from loguru import logger
from PIL import Image, ImageFile, ImageOps
import numpy as np
@@ -35,26 +36,17 @@ IMAGE_HEIGHT = 56
IMAGE_WIDTH = 1024
+@attr.s(auto_attribs=True)
class IAMLines(BaseDataModule):
"""IAM handwritten lines dataset."""
- def __init__(
- self,
- augment: bool = True,
- fraction: float = 0.8,
- batch_size: int = 128,
- num_workers: int = 0,
- ) -> None:
- # TODO: add transforms
- super().__init__(batch_size, num_workers)
- self.augment = augment
- self.fraction = fraction
+ augment: bool = attr.ib(default=True)
+ fraction: float = attr.ib(default=0.8)
+
+ def __attrs_post_init__(self) -> None:
self.mapping, self.inverse_mapping, _ = emnist_mapping()
self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
self.output_dims = (89, 1)
- self.data_train: BaseDataset = None
- self.data_val: BaseDataset = None
- self.data_test: BaseDataset = None
def prepare_data(self) -> None:
"""Creates the IAM lines dataset if not existing."""
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index fe60e99..445b788 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -3,6 +3,7 @@ import json
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple
+import attr
from loguru import logger
import numpy as np
from PIL import Image, ImageOps
@@ -33,33 +34,25 @@ IMAGE_WIDTH = 1280 // IMAGE_SCALE_FACTOR
MAX_LABEL_LENGTH = 682
+@attr.s(auto_attribs=True)
class IAMParagraphs(BaseDataModule):
"""IAM handwriting database paragraphs."""
- def __init__(
- self,
- batch_size: int = 16,
- num_workers: int = 0,
- train_fraction: float = 0.8,
- augment: bool = True,
- word_pieces: bool = False,
- ) -> None:
- super().__init__(batch_size, num_workers)
- self.augment = augment
- self.word_pieces = word_pieces
+ augment: bool = attr.ib(default=True)
+ train_fraction: float = attr.ib(default=0.8)
+ word_pieces: bool = attr.ib(default=False)
+
+ def __attrs_post_init__(self) -> None:
self.mapping, self.inverse_mapping, _ = emnist_mapping(
extra_symbols=[NEW_LINE_TOKEN]
)
- if word_pieces:
+ if self.word_pieces:
self.mapping = WordPieceMapping()
self.train_fraction = train_fraction
self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
self.output_dims = (MAX_LABEL_LENGTH, 1)
- self.data_train: BaseDataset = None
- self.data_val: BaseDataset = None
- self.data_test: BaseDataset = None
def prepare_data(self) -> None:
"""Create data for training/testing."""
@@ -166,7 +159,10 @@ def get_dataset_properties() -> Dict:
"min": min(_get_property_values("num_lines")),
"max": max(_get_property_values("num_lines")),
},
- "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0),},
+ "crop_shape": {
+ "min": crop_shapes.min(axis=0),
+ "max": crop_shapes.max(axis=0),
+ },
"aspect_ratio": {
"min": aspect_ratio.min(axis=0),
"max": aspect_ratio.max(axis=0),
@@ -287,7 +283,9 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> T.Compose:
),
T.ColorJitter(brightness=(0.8, 1.6)),
T.RandomAffine(
- degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR,
+ degrees=1,
+ shear=(-10, 10),
+ interpolation=InterpolationMode.BILINEAR,
),
]
else:
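
The other half of the commit message, refactoring the conf for hydra, touches config files outside this diff. A hypothetical sketch of how such a config could drive one of these attrs-based modules follows; the _target_ path and field names match the diff above, but the values are assumptions, and running it requires hydra-core, omegaconf, and an importable text_recognizer package.

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    # Hypothetical config node; in a real run this would come from the Hydra
    # config tree rather than being created inline.
    cfg = OmegaConf.create(
        {
            "_target_": "text_recognizer.data.iam_lines.IAMLines",
            "batch_size": 128,
            "num_workers": 8,
            "augment": True,
            "fraction": 0.8,
        }
    )

    # Builds IAMLines through its attrs-generated __init__.
    datamodule = instantiate(cfg)
    print(type(datamodule).__name__, datamodule.batch_size)
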