diff options
-rw-r--r-- | text_recognizer/data/base_data_module.py | 2 | ||||
-rw-r--r-- | text_recognizer/data/base_dataset.py | 7 | ||||
-rw-r--r-- | text_recognizer/data/emnist.py | 16 | ||||
-rw-r--r-- | text_recognizer/data/iam_lines.py | 27 | ||||
-rw-r--r-- | text_recognizer/data/iam_paragraphs.py | 2 | ||||
-rw-r--r-- | text_recognizer/data/iam_synthetic_paragraphs.py | 2 | ||||
-rw-r--r-- | text_recognizer/data/mappings/emnist.py | 22 |
7 files changed, 43 insertions, 35 deletions
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py index 6306cf8..77d15e5 100644 --- a/text_recognizer/data/base_data_module.py +++ b/text_recognizer/data/base_data_module.py @@ -6,7 +6,7 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader from text_recognizer.data.base_dataset import BaseDataset -from text_recognizer.data.mappings.base import AbstractMapping +from text_recognizer.data.mappings import AbstractMapping T = TypeVar("T") diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py index 675683a..4ceb818 100644 --- a/text_recognizer/data/base_dataset.py +++ b/text_recognizer/data/base_dataset.py @@ -32,13 +32,6 @@ class BaseDataset(Dataset): self.targets = targets self.transform = transform self.target_transform = target_transform - - def __attrs_pre_init__(self) -> None: - """Pre init constructor.""" - super().__init__() - - def __attrs_post_init__(self) -> None: - """Post init constructor.""" if len(self.data) != len(self.targets): raise ValueError("Data and targets must be of equal length.") self.transform = self._load_transform(self.transform) diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py index 1b1381a..ea27984 100644 --- a/text_recognizer/data/emnist.py +++ b/text_recognizer/data/emnist.py @@ -102,22 +102,6 @@ class EMNIST(BaseDataModule): return basic + data -def emnist_mapping( - extra_symbols: Optional[Set[str]] = None, -) -> Tuple[List, Dict[str, int], List[int]]: - """Return the EMNIST mapping.""" - if not ESSENTIALS_FILENAME.exists(): - download_and_process_emnist() - with ESSENTIALS_FILENAME.open() as f: - essentials = json.load(f) - mapping = list(essentials["characters"]) - if extra_symbols is not None: - mapping += extra_symbols - inverse_mapping = {v: k for k, v in enumerate(mapping)} - input_shape = essentials["input_shape"] - return mapping, inverse_mapping, input_shape - - def download_and_process_emnist() -> None: """Downloads and preprocesses EMNIST dataset.""" metadata = toml.load(METADATA_FILENAME) diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py index 5f38f14..c23dec6 100644 --- a/text_recognizer/data/iam_lines.py +++ b/text_recognizer/data/iam_lines.py @@ -5,7 +5,7 @@ dataset. """ import json from pathlib import Path -from typing import List, Sequence, Tuple +from typing import Callable, List, Optional, Sequence, Tuple, Type from loguru import logger as log import numpy as np @@ -19,7 +19,7 @@ from text_recognizer.data.base_dataset import ( split_dataset, ) from text_recognizer.data.iam import IAM -from text_recognizer.data.mappings.emnist import EmnistMapping +from text_recognizer.data.mappings import AbstractMapping, EmnistMapping from text_recognizer.data.transforms.load_transform import load_transform_from_file from text_recognizer.data.utils import image_utils @@ -37,8 +37,27 @@ MAX_WORD_PIECE_LENGTH = 72 class IAMLines(BaseDataModule): """IAM handwritten lines dataset.""" - def __init__(self) -> None: - super().__init__() + def __init__( + self, + mapping: Type[AbstractMapping], + transform: Optional[Callable] = None, + test_transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + train_fraction: float = 0.8, + batch_size: int = 16, + num_workers: int = 0, + pin_memory: bool = True, + ) -> None: + super().__init__( + mapping, + transform, + test_transform, + target_transform, + train_fraction, + batch_size, + num_workers, + pin_memory, + ) self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH) self.output_dims = (MAX_LABEL_LENGTH, 1) diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index dde505d..9c75129 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -16,7 +16,7 @@ from text_recognizer.data.base_dataset import ( split_dataset, ) from text_recognizer.data.iam import IAM -from text_recognizer.data.mappings.emnist import EmnistMapping +from text_recognizer.data.mappings import EmnistMapping from text_recognizer.data.transforms.load_transform import load_transform_from_file diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py index 2e7762c..1dc517d 100644 --- a/text_recognizer/data/iam_synthetic_paragraphs.py +++ b/text_recognizer/data/iam_synthetic_paragraphs.py @@ -24,7 +24,7 @@ from text_recognizer.data.iam_paragraphs import ( NEW_LINE_TOKEN, resize_image, ) -from text_recognizer.data.mappings.emnist import EmnistMapping +from text_recognizer.data.mappings import EmnistMapping from text_recognizer.data.transforms.load_transform import load_transform_from_file diff --git a/text_recognizer/data/mappings/emnist.py b/text_recognizer/data/mappings/emnist.py index 51e4677..ecd862e 100644 --- a/text_recognizer/data/mappings/emnist.py +++ b/text_recognizer/data/mappings/emnist.py @@ -1,12 +1,15 @@ """Emnist mapping.""" -from typing import List, Optional, Sequence, Union +import json +from pathlib import Path +from typing import Dict, List, Optional, Sequence, Union, Tuple import torch from torch import Tensor -from text_recognizer.data.emnist import emnist_mapping from text_recognizer.data.mappings.base import AbstractMapping +ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json" + class EmnistMapping(AbstractMapping): """Mapping for EMNIST labels.""" @@ -15,13 +18,22 @@ class EmnistMapping(AbstractMapping): self, extra_symbols: Optional[Sequence[str]] = None, lower: bool = True ) -> None: self.extra_symbols = set(extra_symbols) if extra_symbols is not None else None - self.mapping, self.inverse_mapping, self.input_size = emnist_mapping( - self.extra_symbols - ) + self.mapping, self.inverse_mapping, self.input_size = self._load_mapping() if lower: self._to_lower() super().__init__(self.input_size, self.mapping, self.inverse_mapping) + def _load_mapping(self) -> Tuple[List, Dict[str, int], List[int]]: + """Return the EMNIST mapping.""" + with ESSENTIALS_FILENAME.open() as f: + essentials = json.load(f) + mapping = list(essentials["characters"]) + if self.extra_symbols is not None: + mapping += self.extra_symbols + inverse_mapping = {v: k for k, v in enumerate(mapping)} + input_shape = essentials["input_shape"] + return mapping, inverse_mapping, input_shape + def _to_lower(self) -> None: """Converts mapping to lowercase letters only.""" |