diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-09 23:31:31 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-09 23:31:31 +0200 |
commit | 2b63fd952bdc9c7c72edd501cbcdbf3231e98f00 (patch) | |
tree | 1c0e0898cb8b66faff9e5d410aa1f82d13542f68 /src/text_recognizer | |
parent | e1b504bca41a9793ed7e88ef14f2e2cbd85724f2 (diff) |
Created an abstract Dataset class for common methods.
Diffstat (limited to 'src/text_recognizer')
-rw-r--r-- | src/text_recognizer/datasets/__init__.py | 16 | ||||
-rw-r--r-- | src/text_recognizer/datasets/dataset.py | 124 | ||||
-rw-r--r-- | src/text_recognizer/datasets/emnist_dataset.py | 228 | ||||
-rw-r--r-- | src/text_recognizer/datasets/emnist_lines_dataset.py | 56 | ||||
-rw-r--r-- | src/text_recognizer/datasets/iam_dataset.py | 6 | ||||
-rw-r--r-- | src/text_recognizer/datasets/iam_lines_dataset.py | 68 | ||||
-rw-r--r-- | src/text_recognizer/datasets/iam_paragraphs_dataset.py | 70 | ||||
-rw-r--r-- | src/text_recognizer/datasets/sentence_generator.py | 2 | ||||
-rw-r--r-- | src/text_recognizer/datasets/util.py | 125 | ||||
-rw-r--r-- | src/text_recognizer/models/base.py | 2 | ||||
-rw-r--r-- | src/text_recognizer/networks/ctc.py | 2 |
11 files changed, 338 insertions, 361 deletions
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py index ede4541..a3af9b1 100644 --- a/src/text_recognizer/datasets/__init__.py +++ b/src/text_recognizer/datasets/__init__.py @@ -1,10 +1,5 @@ """Dataset modules.""" -from .emnist_dataset import ( - DATA_DIRNAME, - EmnistDataset, - EmnistMapper, - ESSENTIALS_FILENAME, -) +from .emnist_dataset import EmnistDataset, Transpose from .emnist_lines_dataset import ( construct_image_from_string, EmnistLinesDataset, @@ -13,7 +8,14 @@ from .emnist_lines_dataset import ( from .iam_dataset import IamDataset from .iam_lines_dataset import IamLinesDataset from .iam_paragraphs_dataset import IamParagraphsDataset -from .util import _download_raw_dataset, compute_sha256, download_url, Transpose +from .util import ( + _download_raw_dataset, + compute_sha256, + DATA_DIRNAME, + download_url, + EmnistMapper, + ESSENTIALS_FILENAME, +) __all__ = [ "_download_raw_dataset", diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py new file mode 100644 index 0000000..f328a0f --- /dev/null +++ b/src/text_recognizer/datasets/dataset.py @@ -0,0 +1,124 @@ +"""Abstract dataset class.""" +from typing import Callable, Dict, Optional, Tuple, Union + +import torch +from torch import Tensor +from torch.utils import data +from torchvision.transforms import ToTensor + +from text_recognizer.datasets.util import EmnistMapper + + +class Dataset(data.Dataset): + """Abstract class for with common methods for all datasets.""" + + def __init__( + self, + train: bool, + subsample_fraction: float = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + ) -> None: + """Initialization of Dataset class. + + Args: + train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False. + subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None. + transform (Optional[Callable]): Transform(s) for input data. Defaults to None. + target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None. + + Raises: + ValueError: If subsample_fraction is not None and outside the range (0, 1). + + """ + self.train = train + self.split = "train" if self.train else "test" + + if subsample_fraction is not None: + if not 0.0 < subsample_fraction < 1.0: + raise ValueError("The subsample fraction must be in (0, 1).") + self.subsample_fraction = subsample_fraction + + self._mapper = EmnistMapper() + self._input_shape = self._mapper.input_shape + self._output_shape = self._mapper._num_classes + self.num_classes = self.mapper.num_classes + + # Set transforms. + self.transform = transform + if self.transform is None: + self.transform = ToTensor() + + self.target_transform = target_transform + if self.target_transform is None: + self.target_transform = torch.tensor + + self._data = None + self._targets = None + + @property + def data(self) -> Tensor: + """The input data.""" + return self._data + + @property + def targets(self) -> Tensor: + """The target data.""" + return self._targets + + @property + def input_shape(self) -> Tuple: + """Input shape of the data.""" + return self._input_shape + + @property + def output_shape(self) -> Tuple: + """Output shape of the data.""" + return self._output_shape + + @property + def mapper(self) -> EmnistMapper: + """Returns the EmnistMapper.""" + return self._mapper + + @property + def mapping(self) -> Dict: + """Return EMNIST mapping from index to character.""" + return self._mapper.mapping + + @property + def inverse_mapping(self) -> Dict: + """Returns the inverse mapping from character to index.""" + return self.mapper.inverse_mapping + + def _subsample(self) -> None: + """Only this fraction of the data will be loaded.""" + if self.subsample_fraction is None: + return + num_subsample = int(self.data.shape[0] * self.subsample_fraction) + self.data = self.data[:num_subsample] + self.targets = self.targets[:num_subsample] + + def __len__(self) -> int: + """Returns the length of the dataset.""" + return len(self.data) + + def load_or_generate_data(self) -> None: + """Load or generate dataset data.""" + raise NotImplementedError + + def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]: + """Fetches samples from the dataset. + + Args: + index (Union[int, torch.Tensor]): The indices of the samples to fetch. + + Raises: + NotImplementedError: If the method is not implemented in child class. + + """ + raise NotImplementedError + + def __repr__(self) -> str: + """Returns information about the dataset.""" + raise NotImplementedError diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py index 0715aae..81268fb 100644 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ b/src/text_recognizer/datasets/emnist_dataset.py @@ -2,139 +2,26 @@ import json from pathlib import Path -from typing import Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Callable, Optional, Tuple, Union from loguru import logger import numpy as np from PIL import Image import torch from torch import Tensor -from torch.utils.data import DataLoader, Dataset from torchvision.datasets import EMNIST -from torchvision.transforms import Compose, Normalize, ToTensor +from torchvision.transforms import Compose, ToTensor -from text_recognizer.datasets.util import Transpose +from text_recognizer.datasets.dataset import Dataset +from text_recognizer.datasets.util import DATA_DIRNAME -DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data" -ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json" +class Transpose: + """Transposes the EMNIST image to the correct orientation.""" -def save_emnist_essentials(emnsit_dataset: type = EMNIST) -> None: - """Extract and saves EMNIST essentials.""" - labels = emnsit_dataset.classes - labels.sort() - mapping = [(i, str(label)) for i, label in enumerate(labels)] - essentials = { - "mapping": mapping, - "input_shape": tuple(emnsit_dataset[0][0].shape[:]), - } - logger.info("Saving emnist essentials...") - with open(ESSENTIALS_FILENAME, "w") as f: - json.dump(essentials, f) - - -def download_emnist() -> None: - """Download the EMNIST dataset via the PyTorch class.""" - logger.info(f"Data directory is: {DATA_DIRNAME}") - dataset = EMNIST(root=DATA_DIRNAME, split="byclass", download=True) - save_emnist_essentials(dataset) - - -class EmnistMapper: - """Mapper between network output to Emnist character.""" - - def __init__(self) -> None: - """Loads the emnist essentials file with the mapping and input shape.""" - self.essentials = self._load_emnist_essentials() - # Load dataset infromation. - self._mapping = self._augment_emnist_mapping(dict(self.essentials["mapping"])) - self._inverse_mapping = {v: k for k, v in self.mapping.items()} - self._num_classes = len(self.mapping) - self._input_shape = self.essentials["input_shape"] - - def __call__(self, token: Union[str, int, np.uint8]) -> Union[str, int]: - """Maps the token to emnist character or character index. - - If the token is an integer (index), the method will return the Emnist character corresponding to that index. - If the token is a str (Emnist character), the method will return the corresponding index for that character. - - Args: - token (Union[str, int, np.uint8]): Eihter a string or index (integer). - - Returns: - Union[str, int]: The mapping result. - - Raises: - KeyError: If the index or string does not exist in the mapping. - - """ - if (isinstance(token, np.uint8) or isinstance(token, int)) and int( - token - ) in self.mapping: - return self.mapping[int(token)] - elif isinstance(token, str) and token in self._inverse_mapping: - return self._inverse_mapping[token] - else: - raise KeyError(f"Token {token} does not exist in the mappings.") - - @property - def mapping(self) -> Dict: - """Returns the mapping between index and character.""" - return self._mapping - - @property - def inverse_mapping(self) -> Dict: - """Returns the mapping between character and index.""" - return self._inverse_mapping - - @property - def num_classes(self) -> int: - """Returns the number of classes in the dataset.""" - return self._num_classes - - @property - def input_shape(self) -> List[int]: - """Returns the input shape of the Emnist characters.""" - return self._input_shape - - def _load_emnist_essentials(self) -> Dict: - """Load the EMNIST mapping.""" - with open(str(ESSENTIALS_FILENAME)) as f: - essentials = json.load(f) - return essentials - - def _augment_emnist_mapping(self, mapping: Dict) -> Dict: - """Augment the mapping with extra symbols.""" - # Extra symbols in IAM dataset - extra_symbols = [ - " ", - "!", - '"', - "#", - "&", - "'", - "(", - ")", - "*", - "+", - ",", - "-", - ".", - "/", - ":", - ";", - "?", - ] - - # padding symbol - extra_symbols.append("_") - - max_key = max(mapping.keys()) - extra_mapping = {} - for i, symbol in enumerate(extra_symbols): - extra_mapping[max_key + 1 + i] = symbol - - return {**mapping, **extra_mapping} + def __call__(self, image: Image) -> np.ndarray: + """Swaps axis.""" + return np.array(image).swapaxes(0, 1) class EmnistDataset(Dataset): @@ -159,70 +46,33 @@ class EmnistDataset(Dataset): target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None. seed (int): Seed number. Defaults to 4711. - Raises: - ValueError: If subsample_fraction is not None and outside the range (0, 1). - """ + super().__init__( + train=train, + subsample_fraction=subsample_fraction, + transform=transform, + target_transform=target_transform, + ) - self.train = train self.sample_to_balance = sample_to_balance - if subsample_fraction is not None: - if not 0.0 < subsample_fraction < 1.0: - raise ValueError("The subsample fraction must be in (0, 1).") - self.subsample_fraction = subsample_fraction - - self.transform = transform - if self.transform is None: + # Have to transpose the emnist characters, ToTensor norms input between [0,1]. + if transform is None: self.transform = Compose([Transpose(), ToTensor()]) + # The EMNIST dataset is already casted to tensors. self.target_transform = target_transform - self.seed = seed - - self._mapper = EmnistMapper() - self._input_shape = self._mapper.input_shape - self.num_classes = self._mapper.num_classes - - # Load dataset. - self._data, self._targets = self.load_emnist_dataset() - - @property - def data(self) -> Tensor: - """The input data.""" - return self._data - @property - def targets(self) -> Tensor: - """The target data.""" - return self._targets - - @property - def input_shape(self) -> Tuple: - """Input shape of the data.""" - return self._input_shape - - @property - def mapper(self) -> EmnistMapper: - """Returns the EmnistMapper.""" - return self._mapper - - @property - def inverse_mapping(self) -> Dict: - """Returns the inverse mapping from character to index.""" - return self.mapper.inverse_mapping - - def __len__(self) -> int: - """Returns the length of the dataset.""" - return len(self.data) + self.seed = seed def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]: """Fetches samples from the dataset. Args: - index (Union[int, torch.Tensor]): The indices of the samples to fetch. + index (Union[int, Tensor]): The indices of the samples to fetch. Returns: - Tuple[torch.Tensor, torch.Tensor]: Data target tuple. + Tuple[Tensor, Tensor]: Data target tuple. """ if torch.is_tensor(index): @@ -248,13 +98,11 @@ class EmnistDataset(Dataset): f"Mapping: {self.mapper.mapping}\n" ) - def _sample_to_balance( - self, data: Tensor, targets: Tensor - ) -> Tuple[np.ndarray, np.ndarray]: + def _sample_to_balance(self) -> None: """Because the dataset is not balanced, we take at most the mean number of instances per class.""" np.random.seed(self.seed) - x = data - y = targets + x = self._data + y = self._targets num_to_sample = int(np.bincount(y.flatten()).mean()) all_sampled_indices = [] for label in np.unique(y.flatten()): @@ -264,22 +112,10 @@ class EmnistDataset(Dataset): indices = np.concatenate(all_sampled_indices) x_sampled = x[indices] y_sampled = y[indices] - data = x_sampled - targets = y_sampled - return data, targets - - def _subsample(self, data: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]: - """Subsamples the dataset to the specified fraction.""" - x = data - y = targets - num_samples = int(x.shape[0] * self.subsample_fraction) - x_sampled = x[:num_samples] - y_sampled = y[:num_samples] - self.data = x_sampled - self.targets = y_sampled - return data, targets + self._data = x_sampled + self._targets = y_sampled - def load_emnist_dataset(self) -> Tuple[Tensor, Tensor]: + def load_or_generate_data(self) -> None: """Fetch the EMNIST dataset.""" dataset = EMNIST( root=DATA_DIRNAME, @@ -290,13 +126,11 @@ class EmnistDataset(Dataset): target_transform=None, ) - data = dataset.data - targets = dataset.targets + self._data = dataset.data + self._targets = dataset.targets if self.sample_to_balance: - data, targets = self._sample_to_balance(data, targets) + self._sample_to_balance() if self.subsample_fraction is not None: - data, targets = self._subsample(data, targets) - - return data, targets + self._subsample() diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py index 656131a..8fa77cd 100644 --- a/src/text_recognizer/datasets/emnist_lines_dataset.py +++ b/src/text_recognizer/datasets/emnist_lines_dataset.py @@ -9,17 +9,16 @@ from loguru import logger import numpy as np import torch from torch import Tensor -from torch.utils.data import Dataset from torchvision.transforms import ToTensor -from text_recognizer.datasets import ( +from text_recognizer.datasets.dataset import Dataset +from text_recognizer.datasets.emnist_dataset import EmnistDataset, Transpose +from text_recognizer.datasets.sentence_generator import SentenceGenerator +from text_recognizer.datasets.util import ( DATA_DIRNAME, - EmnistDataset, EmnistMapper, ESSENTIALS_FILENAME, ) -from text_recognizer.datasets.sentence_generator import SentenceGenerator -from text_recognizer.datasets.util import Transpose from text_recognizer.networks import sliding_window DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines" @@ -52,18 +51,11 @@ class EmnistLinesDataset(Dataset): seed (int): Seed number. Defaults to 4711. """ - self.train = train - - self.transform = transform - if self.transform is None: - self.transform = ToTensor() - - self.target_transform = target_transform - if self.target_transform is None: - self.target_transform = torch.tensor + super().__init__( + train=train, transform=transform, target_transform=target_transform, + ) # Extract dataset information. - self._mapper = EmnistMapper() self._input_shape = self._mapper.input_shape self.num_classes = self._mapper.num_classes @@ -75,24 +67,12 @@ class EmnistLinesDataset(Dataset): self.input_shape[0], self.input_shape[1] * self.max_length, ) - self.output_shape = (self.max_length, self.num_classes) + self._output_shape = (self.max_length, self.num_classes) self.seed = seed # Placeholders for the dataset. - self.data = None - self.target = None - - # Load dataset. - self._load_or_generate_data() - - @property - def input_shape(self) -> Tuple: - """Input shape of the data.""" - return self._input_shape - - def __len__(self) -> int: - """Returns the length of the dataset.""" - return len(self.data) + self._data = None + self._target = None def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]: """Fetches data, target pair of the dataset for a given and index or indices. @@ -132,16 +112,6 @@ class EmnistLinesDataset(Dataset): ) @property - def mapper(self) -> EmnistMapper: - """Returns the EmnistMapper.""" - return self._mapper - - @property - def mapping(self) -> Dict: - """Return EMNIST mapping from index to character.""" - return self._mapper.mapping - - @property def data_filename(self) -> Path: """Path to the h5 file.""" filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt" @@ -151,7 +121,7 @@ class EmnistLinesDataset(Dataset): filename = "test_" + filename return DATA_DIRNAME / filename - def _load_or_generate_data(self) -> None: + def load_or_generate_data(self) -> None: """Loads the dataset, if it does not exist a new dataset is generated before loading it.""" np.random.seed(self.seed) @@ -163,8 +133,8 @@ class EmnistLinesDataset(Dataset): """Loads the dataset from the h5 file.""" logger.debug("EmnistLinesDataset loading data from HDF5...") with h5py.File(self.data_filename, "r") as f: - self.data = f["data"][:] - self.targets = f["targets"][:] + self._data = f["data"][:] + self._targets = f["targets"][:] def _generate_data(self) -> str: """Generates a dataset with the Brown corpus and Emnist characters.""" diff --git a/src/text_recognizer/datasets/iam_dataset.py b/src/text_recognizer/datasets/iam_dataset.py index 5e47350..f4a869d 100644 --- a/src/text_recognizer/datasets/iam_dataset.py +++ b/src/text_recognizer/datasets/iam_dataset.py @@ -7,10 +7,8 @@ from boltons.cacheutils import cachedproperty import defusedxml.ElementTree as ET from loguru import logger import toml -from torch.utils.data import Dataset -from text_recognizer.datasets import DATA_DIRNAME -from text_recognizer.datasets.util import _download_raw_dataset +from text_recognizer.datasets.util import _download_raw_dataset, DATA_DIRNAME RAW_DATA_DIRNAME = DATA_DIRNAME / "raw" / "iam" METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" @@ -20,7 +18,7 @@ DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be. LINE_REGION_PADDING = 0 # Add this many pixels around the exact coordinates. -class IamDataset(Dataset): +class IamDataset: """IAM dataset. "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text, diff --git a/src/text_recognizer/datasets/iam_lines_dataset.py b/src/text_recognizer/datasets/iam_lines_dataset.py index 477f500..4a74b2b 100644 --- a/src/text_recognizer/datasets/iam_lines_dataset.py +++ b/src/text_recognizer/datasets/iam_lines_dataset.py @@ -5,11 +5,15 @@ import h5py from loguru import logger import torch from torch import Tensor -from torch.utils.data import Dataset from torchvision.transforms import ToTensor -from text_recognizer.datasets.emnist_dataset import DATA_DIRNAME, EmnistMapper -from text_recognizer.datasets.util import compute_sha256, download_url +from text_recognizer.datasets.dataset import Dataset +from text_recognizer.datasets.util import ( + compute_sha256, + DATA_DIRNAME, + download_url, + EmnistMapper, +) PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_lines" @@ -29,47 +33,26 @@ class IamLinesDataset(Dataset): transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, ) -> None: - self.train = train - self.split = "train" if self.train else "test" - self._mapper = EmnistMapper() - self.num_classes = self.mapper.num_classes - - # Set transforms. - self.transform = transform - if self.transform is None: - self.transform = ToTensor() - - self.target_transform = target_transform - if self.target_transform is None: - self.target_transform = torch.tensor - - self.subsample_fraction = subsample_fraction - self.data = None - self.targets = None - - @property - def mapper(self) -> EmnistMapper: - """Returns the EmnistMapper.""" - return self._mapper - - @property - def mapping(self) -> Dict: - """Return EMNIST mapping from index to character.""" - return self._mapper.mapping + super().__init__( + train=train, + subsample_fraction=subsample_fraction, + transform=transform, + target_transform=target_transform, + ) @property def input_shape(self) -> Tuple: """Input shape of the data.""" - return self.data.shape[1:] + return self.data.shape[1:] if self.data is not None else None @property def output_shape(self) -> Tuple: """Output shape of the data.""" - return self.targets.shape[1:] + (self.num_classes,) - - def __len__(self) -> int: - """Returns the length of the dataset.""" - return len(self.data) + return ( + self.targets.shape[1:] + (self.num_classes,) + if self.targets is not None + else None + ) def load_or_generate_data(self) -> None: """Load or generate dataset data.""" @@ -78,19 +61,10 @@ class IamLinesDataset(Dataset): logger.info("Downloading IAM lines...") download_url(PROCESSED_DATA_URL, PROCESSED_DATA_FILENAME) with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: - self.data = f[f"x_{self.split}"][:] - self.targets = f[f"y_{self.split}"][:] + self._data = f[f"x_{self.split}"][:] + self._targets = f[f"y_{self.split}"][:] self._subsample() - def _subsample(self) -> None: - """Only a fraction of the data will be loaded.""" - if self.subsample_fraction is None: - return - - num_samples = int(self.data.shape[0] * self.subsample_fraction) - self.data = self.data[:num_samples] - self.targets = self.targets[:num_samples] - def __repr__(self) -> str: """Print info about the dataset.""" return ( diff --git a/src/text_recognizer/datasets/iam_paragraphs_dataset.py b/src/text_recognizer/datasets/iam_paragraphs_dataset.py index d65b346..4b34bd1 100644 --- a/src/text_recognizer/datasets/iam_paragraphs_dataset.py +++ b/src/text_recognizer/datasets/iam_paragraphs_dataset.py @@ -8,13 +8,17 @@ from loguru import logger import numpy as np import torch from torch import Tensor -from torch.utils.data import Dataset from torchvision.transforms import ToTensor from text_recognizer import util -from text_recognizer.datasets.emnist_dataset import DATA_DIRNAME, EmnistMapper +from text_recognizer.datasets.dataset import Dataset from text_recognizer.datasets.iam_dataset import IamDataset -from text_recognizer.datasets.util import compute_sha256, download_url +from text_recognizer.datasets.util import ( + compute_sha256, + DATA_DIRNAME, + download_url, + EmnistMapper, +) INTERIM_DATA_DIRNAME = DATA_DIRNAME / "interim" / "iam_paragraphs" DEBUG_CROPS_DIRNAME = INTERIM_DATA_DIRNAME / "debug_crops" @@ -28,11 +32,7 @@ SEED = 4711 class IamParagraphsDataset(Dataset): - """IAM Paragraphs dataset for paragraphs of handwritten text. - - TODO: __getitem__, __len__, get_data_target_from_id - - """ + """IAM Paragraphs dataset for paragraphs of handwritten text.""" def __init__( self, @@ -41,34 +41,20 @@ class IamParagraphsDataset(Dataset): transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, ) -> None: - + super().__init__( + train=train, + subsample_fraction=subsample_fraction, + transform=transform, + target_transform=target_transform, + ) # Load Iam dataset. self.iam_dataset = IamDataset() - self.train = train - self.split = "train" if self.train else "test" self.num_classes = 3 self._input_shape = (256, 256) self._output_shape = self._input_shape + (self.num_classes,) - self.subsample_fraction = subsample_fraction - - # Set transforms. - self.transform = transform - if self.transform is None: - self.transform = ToTensor() - - self.target_transform = target_transform - if self.target_transform is None: - self.target_transform = torch.tensor - - self._data = None - self._targets = None self._ids = None - def __len__(self) -> int: - """Returns the length of the dataset.""" - return len(self.data) - def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]: """Fetches data, target pair of the dataset for a given and index or indices. @@ -94,26 +80,6 @@ class IamParagraphsDataset(Dataset): return data, targets @property - def input_shape(self) -> Tuple: - """Input shape of the data.""" - return self._input_shape - - @property - def output_shape(self) -> Tuple: - """Output shape of the data.""" - return self._output_shape - - @property - def data(self) -> Tensor: - """The input data.""" - return self._data - - @property - def targets(self) -> Tensor: - """The target data.""" - return self._targets - - @property def ids(self) -> Tensor: """Ids of the dataset.""" return self._ids @@ -201,14 +167,6 @@ class IamParagraphsDataset(Dataset): logger.info(f"Setting them to {max_crop_width}x{max_crop_width}") return crop_dims - def _subsample(self) -> None: - """Only this fraction of the data will be loaded.""" - if self.subsample_fraction is None: - return - num_subsample = int(self.data.shape[0] * self.subsample_fraction) - self.data = self.data[:num_subsample] - self.targets = self.targets[:num_subsample] - def __repr__(self) -> str: """Return info about the dataset.""" return ( diff --git a/src/text_recognizer/datasets/sentence_generator.py b/src/text_recognizer/datasets/sentence_generator.py index ee86bd4..dd76652 100644 --- a/src/text_recognizer/datasets/sentence_generator.py +++ b/src/text_recognizer/datasets/sentence_generator.py @@ -9,7 +9,7 @@ import nltk from nltk.corpus.reader.util import ConcatenatedCorpusView import numpy as np -from text_recognizer.datasets import DATA_DIRNAME +from text_recognizer.datasets.util import DATA_DIRNAME NLTK_DATA_DIRNAME = DATA_DIRNAME / "raw" / "nltk" diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py index dd16bed..3acf5db 100644 --- a/src/text_recognizer/datasets/util.py +++ b/src/text_recognizer/datasets/util.py @@ -1,6 +1,7 @@ """Util functions for datasets.""" import hashlib import importlib +import json import os from pathlib import Path from typing import Callable, Dict, List, Optional, Type, Union @@ -11,15 +12,129 @@ from loguru import logger import numpy as np from PIL import Image from torch.utils.data import DataLoader, Dataset +from torchvision.datasets import EMNIST from tqdm import tqdm +DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data" +ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json" -class Transpose: - """Transposes the EMNIST image to the correct orientation.""" - def __call__(self, image: Image) -> np.ndarray: - """Swaps axis.""" - return np.array(image).swapaxes(0, 1) +def save_emnist_essentials(emnsit_dataset: type = EMNIST) -> None: + """Extract and saves EMNIST essentials.""" + labels = emnsit_dataset.classes + labels.sort() + mapping = [(i, str(label)) for i, label in enumerate(labels)] + essentials = { + "mapping": mapping, + "input_shape": tuple(emnsit_dataset[0][0].shape[:]), + } + logger.info("Saving emnist essentials...") + with open(ESSENTIALS_FILENAME, "w") as f: + json.dump(essentials, f) + + +def download_emnist() -> None: + """Download the EMNIST dataset via the PyTorch class.""" + logger.info(f"Data directory is: {DATA_DIRNAME}") + dataset = EMNIST(root=DATA_DIRNAME, split="byclass", download=True) + save_emnist_essentials(dataset) + + +class EmnistMapper: + """Mapper between network output to Emnist character.""" + + def __init__(self) -> None: + """Loads the emnist essentials file with the mapping and input shape.""" + self.essentials = self._load_emnist_essentials() + # Load dataset infromation. + self._mapping = self._augment_emnist_mapping(dict(self.essentials["mapping"])) + self._inverse_mapping = {v: k for k, v in self.mapping.items()} + self._num_classes = len(self.mapping) + self._input_shape = self.essentials["input_shape"] + + def __call__(self, token: Union[str, int, np.uint8]) -> Union[str, int]: + """Maps the token to emnist character or character index. + + If the token is an integer (index), the method will return the Emnist character corresponding to that index. + If the token is a str (Emnist character), the method will return the corresponding index for that character. + + Args: + token (Union[str, int, np.uint8]): Eihter a string or index (integer). + + Returns: + Union[str, int]: The mapping result. + + Raises: + KeyError: If the index or string does not exist in the mapping. + + """ + if (isinstance(token, np.uint8) or isinstance(token, int)) and int( + token + ) in self.mapping: + return self.mapping[int(token)] + elif isinstance(token, str) and token in self._inverse_mapping: + return self._inverse_mapping[token] + else: + raise KeyError(f"Token {token} does not exist in the mappings.") + + @property + def mapping(self) -> Dict: + """Returns the mapping between index and character.""" + return self._mapping + + @property + def inverse_mapping(self) -> Dict: + """Returns the mapping between character and index.""" + return self._inverse_mapping + + @property + def num_classes(self) -> int: + """Returns the number of classes in the dataset.""" + return self._num_classes + + @property + def input_shape(self) -> List[int]: + """Returns the input shape of the Emnist characters.""" + return self._input_shape + + def _load_emnist_essentials(self) -> Dict: + """Load the EMNIST mapping.""" + with open(str(ESSENTIALS_FILENAME)) as f: + essentials = json.load(f) + return essentials + + def _augment_emnist_mapping(self, mapping: Dict) -> Dict: + """Augment the mapping with extra symbols.""" + # Extra symbols in IAM dataset + extra_symbols = [ + " ", + "!", + '"', + "#", + "&", + "'", + "(", + ")", + "*", + "+", + ",", + "-", + ".", + "/", + ":", + ";", + "?", + ] + + # padding symbol + extra_symbols.append("_") + + max_key = max(mapping.keys()) + extra_mapping = {} + for i, symbol in enumerate(extra_symbols): + extra_mapping[max_key + 1 + i] = symbol + + return {**mapping, **extra_mapping} def compute_sha256(filename: Union[Path, str]) -> str: diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index 153e19a..d23fe56 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -140,6 +140,7 @@ class Model(ABC): if not self.data_prepared: # Load train dataset. train_dataset = self.dataset(train=True, **self.dataset_args["args"]) + train_dataset.load_or_generate_data() # Set input shape. self._input_shape = train_dataset.input_shape @@ -156,6 +157,7 @@ class Model(ABC): # Load test dataset. self.test_dataset = self.dataset(train=False, **self.dataset_args["args"]) + self.test_dataset.load_or_generate_data() # Set the flag to true to disable ability to load data agian. self.data_prepared = True diff --git a/src/text_recognizer/networks/ctc.py b/src/text_recognizer/networks/ctc.py index fc0d21d..72f18b8 100644 --- a/src/text_recognizer/networks/ctc.py +++ b/src/text_recognizer/networks/ctc.py @@ -5,7 +5,7 @@ from einops import rearrange import torch from torch import Tensor -from text_recognizer.datasets import EmnistMapper +from text_recognizer.datasets.util import EmnistMapper def greedy_decoder( |