summary | refs | log | tree | commit | diff
path: root/text_recognizer/data
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data')
-rw-r--r--  text_recognizer/data/base_data_module.py         46
-rw-r--r--  text_recognizer/data/base_dataset.py             19
-rw-r--r--  text_recognizer/data/emnist.py                    6
-rw-r--r--  text_recognizer/data/emnist_lines.py             49
-rw-r--r--  text_recognizer/data/iam.py                       7
-rw-r--r--  text_recognizer/data/iam_extended_paragraphs.py  27
-rw-r--r--  text_recognizer/data/iam_lines.py                10
-rw-r--r--  text_recognizer/data/iam_paragraphs.py           11
-rw-r--r--  text_recognizer/data/iam_synthetic_paragraphs.py  5
9 files changed, 115 insertions, 65 deletions
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index a0c8416..6306cf8 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -2,7 +2,6 @@
from pathlib import Path
from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
-from attrs import define, field
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
@@ -20,29 +19,36 @@ def load_and_print_info(data_module_class: type) -> None:
print(dataset)
-@define(repr=False)
class BaseDataModule(LightningDataModule):
"""Base PyTorch Lightning DataModule."""
- def __attrs_post_init__(self) -> None:
- """Pre init constructor."""
+ def __init__(
+ self,
+ mapping: Type[AbstractMapping],
+ transform: Optional[Callable] = None,
+ test_transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ train_fraction: float = 0.8,
+ batch_size: int = 16,
+ num_workers: int = 0,
+ pin_memory: bool = True,
+ ) -> None:
super().__init__()
-
- mapping: Type[AbstractMapping] = field()
- transform: Optional[Callable] = field(default=None)
- test_transform: Optional[Callable] = field(default=None)
- target_transform: Optional[Callable] = field(default=None)
- train_fraction: float = field(default=0.8)
- batch_size: int = field(default=16)
- num_workers: int = field(default=0)
- pin_memory: bool = field(default=True)
-
- # Placeholders
- data_train: BaseDataset = field(init=False, default=None)
- data_val: BaseDataset = field(init=False, default=None)
- data_test: BaseDataset = field(init=False, default=None)
- dims: Tuple[int, ...] = field(init=False, default=None)
- output_dims: Tuple[int, ...] = field(init=False, default=None)
+ self.mapping = mapping
+ self.transform = transform
+ self.test_transform = test_transform
+ self.target_transform = target_transform
+ self.train_fraction = train_fraction
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.pin_memory = pin_memory
+
+ # Placeholders
+ self.data_train: BaseDataset
+ self.data_val: BaseDataset
+ self.data_test: BaseDataset
+ self.dims: Tuple[int, ...]
+ self.output_dims: Tuple[int, ...]
@classmethod
def data_dirname(cls: T) -> Path:
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py
index c57cbcc..675683a 100644
--- a/text_recognizer/data/base_dataset.py
+++ b/text_recognizer/data/base_dataset.py
@@ -1,7 +1,6 @@
"""Base PyTorch Dataset class."""
from typing import Callable, Dict, Optional, Sequence, Tuple, Union
-from attrs import define, field
import torch
from torch import Tensor
from torch.utils.data import Dataset
@@ -9,7 +8,6 @@ from torch.utils.data import Dataset
from text_recognizer.data.transforms.load_transform import load_transform_from_file
-@define
class BaseDataset(Dataset):
r"""Base Dataset class that processes data and targets through optional transfroms.
@@ -21,10 +19,19 @@ class BaseDataset(Dataset):
target transforms.
"""
- data: Union[Sequence, Tensor] = field()
- targets: Union[Sequence, Tensor] = field()
- transform: Union[Optional[Callable], str] = field(default=None)
- target_transform: Union[Optional[Callable], str] = field(default=None)
+ def __init__(
+ self,
+ data: Union[Sequence, Tensor],
+ targets: Union[Sequence, Tensor],
+ transform: Union[Optional[Callable], str],
+ target_transform: Union[Optional[Callable], str],
+ ) -> None:
+ super().__init__()
+
+ self.data = data
+ self.targets = targets
+ self.transform = transform
+ self.target_transform = target_transform
def __attrs_pre_init__(self) -> None:
"""Pre init constructor."""
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index 94882bf..1b1381a 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -6,7 +6,6 @@ import shutil
from typing import Dict, List, Optional, Sequence, Set, Tuple
import zipfile
-from attrs import define
import h5py
from loguru import logger as log
import numpy as np
@@ -35,7 +34,6 @@ ESSENTIALS_FILENAME = (
)
-@define(auto_attribs=True)
class EMNIST(BaseDataModule):
"""Lightning DataModule class for loading EMNIST dataset.
@@ -48,8 +46,8 @@ class EMNIST(BaseDataModule):
EMNIST ByClass: 814,255 characters. 62 unbalanced classes.
"""
- def __attrs_post_init__(self) -> None:
- """Post init configuration."""
+ def __init__(self) -> None:
+ super().__init__()
self.dims = (1, *self.mapping.input_size)
def prepare_data(self) -> None:
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 43d55b9..062257d 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -1,9 +1,8 @@
"""Dataset of generated text from EMNIST characters."""
from collections import defaultdict
from pathlib import Path
-from typing import DefaultDict, List, Tuple
+from typing import Callable, DefaultDict, List, Optional, Tuple, Type
-from attrs import define, field
import h5py
from loguru import logger as log
import numpy as np
@@ -17,6 +16,7 @@ from text_recognizer.data.base_data_module import (
)
from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels
from text_recognizer.data.emnist import EMNIST
+from text_recognizer.data.mappings import AbstractMapping
from text_recognizer.data.transforms.load_transform import load_transform_from_file
from text_recognizer.data.utils.sentence_generator import SentenceGenerator
@@ -33,22 +33,45 @@ IMAGE_X_PADDING = 28
MAX_OUTPUT_LENGTH = 89 # Same as IAMLines
-@define(auto_attribs=True, repr=False)
class EMNISTLines(BaseDataModule):
"""EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST."""
- max_length: int = field(default=128)
- min_overlap: float = field(default=0.0)
- max_overlap: float = field(default=0.33)
- num_train: int = field(default=10_000)
- num_val: int = field(default=2_000)
- num_test: int = field(default=2_000)
- emnist: EMNIST = field(init=False, default=None)
+ def __init__(
+ self,
+ mapping: Type[AbstractMapping],
+ transform: Optional[Callable] = None,
+ test_transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ train_fraction: float = 0.8,
+ batch_size: int = 16,
+ num_workers: int = 0,
+ pin_memory: bool = True,
+ max_length: int = 128,
+ min_overlap: float = 0.0,
+ max_overlap: float = 0.33,
+ num_train: int = 10_000,
+ num_val: int = 2_000,
+ num_test: int = 2_000,
+ ) -> None:
+ super().__init__(
+ mapping,
+ transform,
+ test_transform,
+ target_transform,
+ train_fraction,
+ batch_size,
+ num_workers,
+ pin_memory,
+ )
- def __attrs_post_init__(self) -> None:
- """Post init constructor."""
- self.emnist = EMNIST(mapping=self.mapping)
+ self.max_length = max_length
+ self.min_overlap = min_overlap
+ self.max_overlap = max_overlap
+ self.num_train = num_train
+ self.num_val = num_val
+ self.num_test = num_test
+ self.emnist = EMNIST(mapping=self.mapping)
max_width = (
int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap))
+ IMAGE_X_PADDING
diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py
index 8166863..a4d1d21 100644
--- a/text_recognizer/data/iam.py
+++ b/text_recognizer/data/iam.py
@@ -9,7 +9,6 @@ from typing import Any, Dict, List
import xml.etree.ElementTree as ElementTree
import zipfile
-from attrs import define, field
from boltons.cacheutils import cachedproperty
from loguru import logger as log
import toml
@@ -27,7 +26,6 @@ DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be.
LINE_REGION_PADDING = 16 # Add this many pixels around the exact coordinates.
-@define(auto_attribs=True)
class IAM(BaseDataModule):
r"""The IAM Lines dataset.
@@ -44,7 +42,10 @@ class IAM(BaseDataModule):
contributed to one set only.
"""
- metadata: Dict = field(init=False, default=toml.load(METADATA_FILENAME))
+ def __init__(self) -> None:
+ super().__init__()
+
+ self.metadata: Dict = toml.load(METADATA_FILENAME)
def prepare_data(self) -> None:
"""Prepares the IAM dataset."""
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 52c10c3..0c16181 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -1,19 +1,38 @@
"""IAM original and sythetic dataset class."""
-from attrs import define, field
+from typing import Callable, Optional, Type
from torch.utils.data import ConcatDataset
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.iam_paragraphs import IAMParagraphs
+from text_recognizer.data.mappings import AbstractMapping
from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs
from text_recognizer.data.transforms.load_transform import load_transform_from_file
-@define(auto_attribs=True, repr=False)
class IAMExtendedParagraphs(BaseDataModule):
"""A dataset with synthetic and real handwritten paragraph."""
- def __attrs_post_init__(self) -> None:
- """Post init constructor."""
+ def __init__(
+ self,
+ mapping: Type[AbstractMapping],
+ transform: Optional[Callable] = None,
+ test_transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ train_fraction: float = 0.8,
+ batch_size: int = 16,
+ num_workers: int = 0,
+ pin_memory: bool = True,
+ ) -> None:
+ super().__init__(
+ mapping,
+ transform,
+ test_transform,
+ target_transform,
+ train_fraction,
+ batch_size,
+ num_workers,
+ pin_memory,
+ )
self.iam_paragraphs = IAMParagraphs(
mapping=self.mapping,
batch_size=self.batch_size,
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index 34cf605..5f38f14 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -7,7 +7,6 @@ import json
from pathlib import Path
from typing import List, Sequence, Tuple
-from attrs import define, field
from loguru import logger as log
import numpy as np
from PIL import Image, ImageFile, ImageOps
@@ -35,14 +34,13 @@ MAX_LABEL_LENGTH = 89
MAX_WORD_PIECE_LENGTH = 72
-@define(auto_attribs=True, repr=False)
class IAMLines(BaseDataModule):
"""IAM handwritten lines dataset."""
- dims: Tuple[int, int, int] = field(
- init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH)
- )
- output_dims: Tuple[int, int] = field(init=False, default=(MAX_LABEL_LENGTH, 1))
+ def __init__(self) -> None:
+ super().__init__()
+ self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
+ self.output_dims = (MAX_LABEL_LENGTH, 1)
def prepare_data(self) -> None:
"""Creates the IAM lines dataset if not existing."""
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index b605bbc..dde505d 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -3,7 +3,6 @@ import json
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple
-from attrs import define, field
from loguru import logger as log
import numpy as np
from PIL import Image, ImageOps
@@ -33,15 +32,13 @@ MAX_LABEL_LENGTH = 682
MAX_WORD_PIECE_LENGTH = 451
-@define(auto_attribs=True, repr=False)
class IAMParagraphs(BaseDataModule):
"""IAM handwriting database paragraphs."""
- # Placeholders
- dims: Tuple[int, int, int] = field(
- init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH)
- )
- output_dims: Tuple[int, int] = field(init=False, default=(MAX_LABEL_LENGTH, 1))
+ def __init__(self) -> None:
+ super().__init__()
+ self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
+ self.output_dims = (MAX_LABEL_LENGTH, 1)
def prepare_data(self) -> None:
"""Create data for training/testing."""
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index 7143951..2e7762c 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -2,7 +2,6 @@
import random
from typing import Any, List, Sequence, Tuple
-from attrs import define
from loguru import logger as log
import numpy as np
from PIL import Image
@@ -34,10 +33,12 @@ PROCESSED_DATA_DIRNAME = (
)
-@define(auto_attribs=True, repr=False)
class IAMSyntheticParagraphs(IAMParagraphs):
"""IAM Handwriting database of synthetic paragraphs."""
+ def __init__(self) -> None:
+ super().__init__()
+
def prepare_data(self) -> None:
"""Prepare IAM lines to be used to generate paragraphs."""
if PROCESSED_DATA_DIRNAME.exists():