diff options
Diffstat (limited to 'text_recognizer/data')
-rw-r--r-- | text_recognizer/data/base_data_module.py | 46 | ||||
-rw-r--r-- | text_recognizer/data/base_dataset.py | 19 | ||||
-rw-r--r-- | text_recognizer/data/emnist.py | 6 | ||||
-rw-r--r-- | text_recognizer/data/emnist_lines.py | 49 | ||||
-rw-r--r-- | text_recognizer/data/iam.py | 7 | ||||
-rw-r--r-- | text_recognizer/data/iam_extended_paragraphs.py | 27 | ||||
-rw-r--r-- | text_recognizer/data/iam_lines.py | 10 | ||||
-rw-r--r-- | text_recognizer/data/iam_paragraphs.py | 11 | ||||
-rw-r--r-- | text_recognizer/data/iam_synthetic_paragraphs.py | 5 |
9 files changed, 115 insertions, 65 deletions
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py index a0c8416..6306cf8 100644 --- a/text_recognizer/data/base_data_module.py +++ b/text_recognizer/data/base_data_module.py @@ -2,7 +2,6 @@ from pathlib import Path from typing import Callable, Dict, Optional, Tuple, Type, TypeVar -from attrs import define, field from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader @@ -20,29 +19,36 @@ def load_and_print_info(data_module_class: type) -> None: print(dataset) -@define(repr=False) class BaseDataModule(LightningDataModule): """Base PyTorch Lightning DataModule.""" - def __attrs_post_init__(self) -> None: - """Pre init constructor.""" + def __init__( + self, + mapping: Type[AbstractMapping], + transform: Optional[Callable] = None, + test_transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + train_fraction: float = 0.8, + batch_size: int = 16, + num_workers: int = 0, + pin_memory: bool = True, + ) -> None: super().__init__() - - mapping: Type[AbstractMapping] = field() - transform: Optional[Callable] = field(default=None) - test_transform: Optional[Callable] = field(default=None) - target_transform: Optional[Callable] = field(default=None) - train_fraction: float = field(default=0.8) - batch_size: int = field(default=16) - num_workers: int = field(default=0) - pin_memory: bool = field(default=True) - - # Placeholders - data_train: BaseDataset = field(init=False, default=None) - data_val: BaseDataset = field(init=False, default=None) - data_test: BaseDataset = field(init=False, default=None) - dims: Tuple[int, ...] = field(init=False, default=None) - output_dims: Tuple[int, ...] = field(init=False, default=None) + self.mapping = mapping + self.transform = transform + self.test_transform = test_transform + self.target_transform = target_transform + self.train_fraction = train_fraction + self.batch_size = batch_size + self.num_workers = num_workers + self.pin_memory = pin_memory + + # Placeholders + self.data_train: BaseDataset + self.data_val: BaseDataset + self.data_test: BaseDataset + self.dims: Tuple[int, ...] + self.output_dims: Tuple[int, ...] @classmethod def data_dirname(cls: T) -> Path: diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py index c57cbcc..675683a 100644 --- a/text_recognizer/data/base_dataset.py +++ b/text_recognizer/data/base_dataset.py @@ -1,7 +1,6 @@ """Base PyTorch Dataset class.""" from typing import Callable, Dict, Optional, Sequence, Tuple, Union -from attrs import define, field import torch from torch import Tensor from torch.utils.data import Dataset @@ -9,7 +8,6 @@ from torch.utils.data import Dataset from text_recognizer.data.transforms.load_transform import load_transform_from_file -@define class BaseDataset(Dataset): r"""Base Dataset class that processes data and targets through optional transfroms. @@ -21,10 +19,19 @@ class BaseDataset(Dataset): target transforms. """ - data: Union[Sequence, Tensor] = field() - targets: Union[Sequence, Tensor] = field() - transform: Union[Optional[Callable], str] = field(default=None) - target_transform: Union[Optional[Callable], str] = field(default=None) + def __init__( + self, + data: Union[Sequence, Tensor], + targets: Union[Sequence, Tensor], + transform: Union[Optional[Callable], str], + target_transform: Union[Optional[Callable], str], + ) -> None: + super().__init__() + + self.data = data + self.targets = targets + self.transform = transform + self.target_transform = target_transform def __attrs_pre_init__(self) -> None: """Pre init constructor.""" diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py index 94882bf..1b1381a 100644 --- a/text_recognizer/data/emnist.py +++ b/text_recognizer/data/emnist.py @@ -6,7 +6,6 @@ import shutil from typing import Dict, List, Optional, Sequence, Set, Tuple import zipfile -from attrs import define import h5py from loguru import logger as log import numpy as np @@ -35,7 +34,6 @@ ESSENTIALS_FILENAME = ( ) -@define(auto_attribs=True) class EMNIST(BaseDataModule): """Lightning DataModule class for loading EMNIST dataset. @@ -48,8 +46,8 @@ class EMNIST(BaseDataModule): EMNIST ByClass: 814,255 characters. 62 unbalanced classes. """ - def __attrs_post_init__(self) -> None: - """Post init configuration.""" + def __init__(self) -> None: + super().__init__() self.dims = (1, *self.mapping.input_size) def prepare_data(self) -> None: diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py index 43d55b9..062257d 100644 --- a/text_recognizer/data/emnist_lines.py +++ b/text_recognizer/data/emnist_lines.py @@ -1,9 +1,8 @@ """Dataset of generated text from EMNIST characters.""" from collections import defaultdict from pathlib import Path -from typing import DefaultDict, List, Tuple +from typing import Callable, DefaultDict, List, Optional, Tuple, Type -from attrs import define, field import h5py from loguru import logger as log import numpy as np @@ -17,6 +16,7 @@ from text_recognizer.data.base_data_module import ( ) from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels from text_recognizer.data.emnist import EMNIST +from text_recognizer.data.mappings import AbstractMapping from text_recognizer.data.transforms.load_transform import load_transform_from_file from text_recognizer.data.utils.sentence_generator import SentenceGenerator @@ -33,22 +33,45 @@ IMAGE_X_PADDING = 28 MAX_OUTPUT_LENGTH = 89 # Same as IAMLines -@define(auto_attribs=True, repr=False) class EMNISTLines(BaseDataModule): """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST.""" - max_length: int = field(default=128) - min_overlap: float = field(default=0.0) - max_overlap: float = field(default=0.33) - num_train: int = field(default=10_000) - num_val: int = field(default=2_000) - num_test: int = field(default=2_000) - emnist: EMNIST = field(init=False, default=None) + def __init__( + self, + mapping: Type[AbstractMapping], + transform: Optional[Callable] = None, + test_transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + train_fraction: float = 0.8, + batch_size: int = 16, + num_workers: int = 0, + pin_memory: bool = True, + max_length: int = 128, + min_overlap: float = 0.0, + max_overlap: float = 0.33, + num_train: int = 10_000, + num_val: int = 2_000, + num_test: int = 2_000, + ) -> None: + super().__init__( + mapping, + transform, + test_transform, + target_transform, + train_fraction, + batch_size, + num_workers, + pin_memory, + ) - def __attrs_post_init__(self) -> None: - """Post init constructor.""" - self.emnist = EMNIST(mapping=self.mapping) + self.max_length = max_length + self.min_overlap = min_overlap + self.max_overlap = max_overlap + self.num_train = num_train + self.num_val = num_val + self.num_test = num_test + self.emnist = EMNIST(mapping=self.mapping) max_width = ( int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap)) + IMAGE_X_PADDING diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py index 8166863..a4d1d21 100644 --- a/text_recognizer/data/iam.py +++ b/text_recognizer/data/iam.py @@ -9,7 +9,6 @@ from typing import Any, Dict, List import xml.etree.ElementTree as ElementTree import zipfile -from attrs import define, field from boltons.cacheutils import cachedproperty from loguru import logger as log import toml @@ -27,7 +26,6 @@ DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be. LINE_REGION_PADDING = 16 # Add this many pixels around the exact coordinates. -@define(auto_attribs=True) class IAM(BaseDataModule): r"""The IAM Lines dataset. @@ -44,7 +42,10 @@ class IAM(BaseDataModule): contributed to one set only. """ - metadata: Dict = field(init=False, default=toml.load(METADATA_FILENAME)) + def __init__(self) -> None: + super().__init__() + + self.metadata: Dict = toml.load(METADATA_FILENAME) def prepare_data(self) -> None: """Prepares the IAM dataset.""" diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py index 52c10c3..0c16181 100644 --- a/text_recognizer/data/iam_extended_paragraphs.py +++ b/text_recognizer/data/iam_extended_paragraphs.py @@ -1,19 +1,38 @@ """IAM original and sythetic dataset class.""" -from attrs import define, field +from typing import Callable, Optional, Type from torch.utils.data import ConcatDataset from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.iam_paragraphs import IAMParagraphs +from text_recognizer.data.mappings import AbstractMapping from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs from text_recognizer.data.transforms.load_transform import load_transform_from_file -@define(auto_attribs=True, repr=False) class IAMExtendedParagraphs(BaseDataModule): """A dataset with synthetic and real handwritten paragraph.""" - def __attrs_post_init__(self) -> None: - """Post init constructor.""" + def __init__( + self, + mapping: Type[AbstractMapping], + transform: Optional[Callable] = None, + test_transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + train_fraction: float = 0.8, + batch_size: int = 16, + num_workers: int = 0, + pin_memory: bool = True, + ) -> None: + super().__init__( + mapping, + transform, + test_transform, + target_transform, + train_fraction, + batch_size, + num_workers, + pin_memory, + ) self.iam_paragraphs = IAMParagraphs( mapping=self.mapping, batch_size=self.batch_size, diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py index 34cf605..5f38f14 100644 --- a/text_recognizer/data/iam_lines.py +++ b/text_recognizer/data/iam_lines.py @@ -7,7 +7,6 @@ import json from pathlib import Path from typing import List, Sequence, Tuple -from attrs import define, field from loguru import logger as log import numpy as np from PIL import Image, ImageFile, ImageOps @@ -35,14 +34,13 @@ MAX_LABEL_LENGTH = 89 MAX_WORD_PIECE_LENGTH = 72 -@define(auto_attribs=True, repr=False) class IAMLines(BaseDataModule): """IAM handwritten lines dataset.""" - dims: Tuple[int, int, int] = field( - init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH) - ) - output_dims: Tuple[int, int] = field(init=False, default=(MAX_LABEL_LENGTH, 1)) + def __init__(self) -> None: + super().__init__() + self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH) + self.output_dims = (MAX_LABEL_LENGTH, 1) def prepare_data(self) -> None: """Creates the IAM lines dataset if not existing.""" diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index b605bbc..dde505d 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -3,7 +3,6 @@ import json from pathlib import Path from typing import Dict, List, Optional, Sequence, Tuple -from attrs import define, field from loguru import logger as log import numpy as np from PIL import Image, ImageOps @@ -33,15 +32,13 @@ MAX_LABEL_LENGTH = 682 MAX_WORD_PIECE_LENGTH = 451 -@define(auto_attribs=True, repr=False) class IAMParagraphs(BaseDataModule): """IAM handwriting database paragraphs.""" - # Placeholders - dims: Tuple[int, int, int] = field( - init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH) - ) - output_dims: Tuple[int, int] = field(init=False, default=(MAX_LABEL_LENGTH, 1)) + def __init__(self) -> None: + super().__init__() + self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH) + self.output_dims = (MAX_LABEL_LENGTH, 1) def prepare_data(self) -> None: """Create data for training/testing.""" diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py index 7143951..2e7762c 100644 --- a/text_recognizer/data/iam_synthetic_paragraphs.py +++ b/text_recognizer/data/iam_synthetic_paragraphs.py @@ -2,7 +2,6 @@ import random from typing import Any, List, Sequence, Tuple -from attrs import define from loguru import logger as log import numpy as np from PIL import Image @@ -34,10 +33,12 @@ PROCESSED_DATA_DIRNAME = ( ) -@define(auto_attribs=True, repr=False) class IAMSyntheticParagraphs(IAMParagraphs): """IAM Handwriting database of synthetic paragraphs.""" + def __init__(self) -> None: + super().__init__() + def prepare_data(self) -> None: """Prepare IAM lines to be used to generate paragraphs.""" if PROCESSED_DATA_DIRNAME.exists(): |