summaryrefslogtreecommitdiff
path: root/text_recognizer
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-21 20:03:10 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-21 20:03:10 +0100
commitaac452a2dc008338cb543549652da293c14b6b4e (patch)
tree6d018841e28f22eee5289f6cc59c28100a9d023d /text_recognizer
parenta3a40c9c0118039460d5c9fba6a74edc0cdba106 (diff)
Refactor EMNIST dataset
Diffstat (limited to 'text_recognizer')
-rw-r--r--text_recognizer/datasets/base_data_module.py69
-rw-r--r--text_recognizer/datasets/download_utils.py73
-rw-r--r--text_recognizer/datasets/emnist.py194
-rw-r--r--text_recognizer/datasets/emnist_dataset.py131
-rw-r--r--text_recognizer/datasets/emnist_essentials.json1
5 files changed, 336 insertions, 132 deletions
diff --git a/text_recognizer/datasets/base_data_module.py b/text_recognizer/datasets/base_data_module.py
new file mode 100644
index 0000000..09a0a43
--- /dev/null
+++ b/text_recognizer/datasets/base_data_module.py
@@ -0,0 +1,69 @@
+"""Base lightning DataModule class."""
+from pathlib import Path
+from typing import Dict
+
+import pytorch_lightning as pl
+from torch.utils.data import DataLoader
+
+
+def load_and_print_info(data_module_class: type) -> None:
+ """Load EMNISTLines and prints info."""
+ dataset = data_module_class()
+ dataset.prepare_data()
+ dataset.setup()
+ print(dataset)
+
+
+class BaseDataModule(pl.LightningDataModule):
+ """Base PyTorch Lightning DataModule."""
+
+ def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None:
+ super().__init__()
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+
+ # Placeholders for subclasses.
+ self.dims = None
+ self.output_dims = None
+ self.mapping = None
+
+ @classmethod
+ def data_dirname(cls) -> Path:
+ """Return the path to the base data directory."""
+ return Path(__file__).resolve().parents[2] / "data"
+
+ def config(self) -> Dict:
+ """Return important settings of the dataset."""
+ return {"input_dim": self.dims, "output_dims": self.output_dims, "mapping": self.mapping}
+
+ def prepare_data(self) -> None:
+ """Prepare data for training."""
+ pass
+
+ def setup(self, stage: Any = None) -> None:
+ """Split into train, val, test, and set dims.
+
+ Should assign `torch Dataset` objects to self.data_train, self.data_val, and
+ optionally self.data_test.
+
+ Args:
+ stage (Any): Variable to set splits.
+
+ """
+ self.data_train = None
+ self.data_val = None
+ self.data_test = None
+
+
+ def train_dataloader(self) -> DataLoader:
+ """Retun DataLoader for train data."""
+ return DataLoader(self.data_train, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)
+
+ def val_dataloader(self) -> DataLoader:
+ """Return DataLoader for val data."""
+ return DataLoader(self.data_val, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)
+
+ def test_dataloader(self) -> DataLoader:
+ """Return DataLoader for val data."""
+ return DataLoader(self.data_test, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)
+
diff --git a/text_recognizer/datasets/download_utils.py b/text_recognizer/datasets/download_utils.py
new file mode 100644
index 0000000..7a2cab8
--- /dev/null
+++ b/text_recognizer/datasets/download_utils.py
@@ -0,0 +1,73 @@
+"""Util functions for downloading datasets."""
+import hashlib
+from pathlib import Path
+from typing import Dict, List, Optional
+from urllib.request import urlretrieve
+
+from loguru import logger
+from tqdm import tqdm
+
+
+def _compute_sha256(filename: Path) -> str:
+ """Returns the SHA256 checksum of a file."""
+ with filename.open(mode="rb") as f:
+ return hashlib.sha256(f.read()).hexdigest()
+
+
+class TqdmUpTo(tqdm):
+ """TQDM progress bar when downloading files.
+
+ From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
+
+ """
+
+ def update_to(
+ self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None
+ ) -> None:
+ """Updates the progress bar.
+
+ Args:
+ blocks (int): Number of blocks transferred so far. Defaults to 1.
+ block_size (int): Size of each block, in tqdm units. Defaults to 1.
+ total_size (Optional[int]): Total size in tqdm units. Defaults to None.
+ """
+ if total_size is not None:
+ self.total = total_size # pylint: disable=attribute-defined-outside-init
+ self.update(blocks * block_size - self.n)
+
+
+def _download_url(url: str, filename: str) -> None:
+ """Downloads a file from url to filename, with a progress bar."""
+ with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
+ urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec
+
+
+def download_dataset(metadata: Dict, dl_dir: Path) -> Optional[Path]:
+ """Downloads dataset using a metadata file.
+
+ Args:
+ metadata (Dict): A metadata file of the dataset.
+ dl_dir (Path): Download directory for the dataset.
+
+ Returns:
+ Optional[Path]: Returns filename if dataset is downloaded, None if it already
+ exists.
+
+ Raises:
+ ValueError: If the SHA-256 value is not the same between the dataset and
+ the metadata file.
+
+ """
+ dl_dir.mkdir(parents=True, exist_ok=True)
+ filename = dl_dir / metadata["filename"]
+ if filename.exists():
+ return
+ logger.info(f"Downloading raw dataset from {metadata['url']} to {filename}...")
+ _download_url(metadata["url"], filename)
+ logger.info("Computing the SHA-256...")
+ sha256 = _compute_sha256(filename)
+ if sha256 != metadata["sha256"]:
+ raise ValueError(
+ "Downloaded data file SHA-256 does not match that listed in metadata document."
+ )
+ return filename
diff --git a/text_recognizer/datasets/emnist.py b/text_recognizer/datasets/emnist.py
new file mode 100644
index 0000000..e99dbfd
--- /dev/null
+++ b/text_recognizer/datasets/emnist.py
@@ -0,0 +1,194 @@
+"""EMNIST dataset: downloads it from FSDL aws url if not present."""
+from pathlib import Path
+from typing import Sequence, Tuple
+import json
+import os
+import shutil
+import zipfile
+
+import h5py
+import numpy as np
+from loguru import logger
+import toml
+import torch
+from torch.utils.data import random_split
+from torchvision import transforms
+
+from text_recognizer.datasets.base_dataset import BaseDataset
+from text_recognizer.datasets.base_data_module import BaseDataModule, load_print_info
+from text_recognizer.datasets.download_utils import download_dataset
+
+
+SEED = 4711
+NUM_SPECIAL_TOKENS = 4
+SAMPLE_TO_BALANCE = True
+
+RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "emnist"
+METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
+DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "emnist"
+PROCESSED_DATA_DIRNAME = BaseDataset.data_dirname() / "processed" / "emnist"
+PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5"
+ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnsit_essentials.json"
+
+
+class EMNIST(BaseDataModule):
+ """
+ "The EMNIST dataset is a set of handwritten character digits derived from the NIST Special Database 19
+ and converted to a 28x28 pixel image format and dataset structure that directly matches the MNIST dataset."
+ From https://www.nist.gov/itl/iad/image-group/emnist-dataset
+
+ The data split we will use is
+ EMNIST ByClass: 814,255 characters. 62 unbalanced classes.
+ """
+
+ def __init__(self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8) -> None:
+ super().__init__(batch_size, num_workers)
+ if not ESSENTIALS_FILENAME.exists():
+ _download_and_process_emnist()
+ with ESSENTIALS_FILENAME.open() as f:
+ essentials = json.load(f)
+ self.train_fraction = train_fraction
+ self.mapping = list(essentials["characters"])
+ self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)}
+ self.data_train = None
+ self.data_val = None
+ self.data_test = None
+ self.transform = transforms.Compose([transforms.ToTensor()])
+ self.dims = (1, *essentials["input_shape"])
+ self.output_dims = (1,)
+
+ def prepare_data(self) -> None:
+ if not PROCESSED_DATA_FILENAME.exists():
+ _download_and_process_emnist()
+
+ def setup(self, stage: str = None) -> None:
+ if stage == "fit" or stage is None:
+ with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
+ data = f["x_train"][:]
+ targets = f["y_train"][:]
+
+ dataset_train = BaseDataset(data, targets, transform=self.transform)
+ train_size = int(self.train_fraction * len(dataset_train))
+ val_size = len(dataset_train) - train_size
+ self.data_train, self.data_val = random_split(dataset_train, [train_size, val_size], generator=torch.Generator())
+
+ if stage == "test" or stage is None:
+ with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
+ data = f["x_test"][:]
+ targets = f["y_test"][:]
+ self.data_test = BaseDataset(data, targets, transform=self.transform)
+
+
+ def __repr__(self) -> str:
+ basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.dims}\n"
+ if not any([self.data_train, self.data_val, self.data_test]):
+ return basic
+
+ datum, target = next(iter(self.train_dataloader()))
+ data = (
+ f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n"
+ f"Batch x stats: {(datum.shape, datum.dtype, datum.min(), datum.mean(), datum.std(), datum.max())}\n"
+ f"Batch y stats: {(target.shape, target.dtype, target.min(), target.max())}\n"
+ )
+
+ return basic + data
+
+
+def _download_and_process_emnist() -> None:
+ metadata = toml.load(METADATA_FILENAME)
+ download_dataset(metadata, DL_DATA_DIRNAME)
+ _process_raw_dataset(metadata["filename"], DL_DATA_DIRNAME)
+
+
+def _process_raw_dataset(filename: str, dirname: Path) -> None:
+ logger.info("Unzipping EMNIST...")
+ curdir = os.getcwd()
+ os.chdir(dirname)
+ content = zipfile.ZipFile(filename, "r")
+ content.extract("matlab/emnist-byclass.mat")
+
+ from scipy.io import loadmat
+
+ logger.info("Loading training data from .mat file")
+ data = loadmat("matlab/emnist-byclass.mat")
+ x_train = data["dataset"]["train"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2)
+ y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
+ x_test = data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2)
+ y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
+
+ if SAMPLE_TO_BALANCE:
+ logger.info("Balancing classes to reduce amount of data")
+ x_train, y_train = _sample_to_balance(x_train, y_train)
+ x_test, y_test = _sample_to_balance(x_test, y_test)
+
+
+ logger.info("Saving to HDF5 in a compressed format...")
+ PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
+ with h5py.File(PROCESSED_DATA_FILENAME, "w") as f:
+ f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf")
+ f.create_dataset("y_train", data=y_train, dtype="u1", compression="lzf")
+ f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf")
+ f.create_dataset("y_test", data=y_test, dtype="u1", compression="lzf")
+
+ logger.info("Saving essential dataset parameters to text_recognizer/datasets...")
+ mapping = {int(k): chr(v) for k, v in data["dataset"]["mapping"][0, 0]}
+ characters = _augment_emnist_characters(mapping.values())
+ essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])}
+
+ with ESSENTIALS_FILENAME.open(mode="w") as f:
+ json.dump(essentials, f)
+
+ logger.info("Cleaning up...")
+ shutil.rmtree("matlab")
+ os.chdir(curdir)
+
+
+def _sample_to_balance(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ """Balances the dataset by taking the mean number of instances per class."""
+ np.random.seed(SEED)
+ num_to_sample = int(np.bincount(y.flatten()).mean())
+ all_sampled_indices = []
+ for label in np.unique(y.flatten()):
+ indices = np.where(y == label)[0]
+ sampled_indices = np.unique(np.random.choice(indices, num_to_sample))
+ all_sampled_indices.append(sampled_indices)
+ indices = np.concatenate(all_sampled_indices)
+ x_sampled = x[indices]
+ y_sampled= y[indices]
+ return x_sampled, y_sampled
+
+
+def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]:
+ """Augment the mapping with extra symbols."""
+ # Extra characters from the IAM dataset.
+ iam_characters = [
+ " ",
+ "!",
+ '"',
+ "#",
+ "&",
+ "'",
+ "(",
+ ")",
+ "*",
+ "+",
+ ",",
+ "-",
+ ".",
+ "/",
+ ":",
+ ";",
+ "?",
+ ]
+
+ # Also add special tokens for:
+ # - CTC blank token at index 0
+ # - Start token at index 1
+ # - End token at index 2
+ # - Padding token at index 3
+ # Note: Do not forget to update NUM_SPECIAL_TOKENS if changing this!
+ return ["<b>", "<s>", "</s>", "<p>", *characters, *iam_characters]
+
+
+if __name__ == "__main__":
+ load_print_info(EMNIST)
diff --git a/text_recognizer/datasets/emnist_dataset.py b/text_recognizer/datasets/emnist_dataset.py
deleted file mode 100644
index 9884fdf..0000000
--- a/text_recognizer/datasets/emnist_dataset.py
+++ /dev/null
@@ -1,131 +0,0 @@
-"""Emnist dataset: black and white images of handwritten characters (Aa-Zz) and digits (0-9)."""
-
-import json
-from pathlib import Path
-from typing import Callable, Optional, Tuple, Union
-
-from loguru import logger
-import numpy as np
-from PIL import Image
-import torch
-from torch import Tensor
-from torchvision.datasets import EMNIST
-from torchvision.transforms import Compose, ToTensor
-
-from text_recognizer.datasets.dataset import Dataset
-from text_recognizer.datasets.transforms import Transpose
-from text_recognizer.datasets.util import DATA_DIRNAME
-
-
-class EmnistDataset(Dataset):
- """This is a class for resampling and subsampling the PyTorch EMNIST dataset."""
-
- def __init__(
- self,
- pad_token: str = None,
- train: bool = False,
- sample_to_balance: bool = False,
- subsample_fraction: float = None,
- transform: Optional[Callable] = None,
- target_transform: Optional[Callable] = None,
- seed: int = 4711,
- ) -> None:
- """Loads the dataset and the mappings.
-
- Args:
- pad_token (str): The pad token symbol. Defaults to _.
- train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False.
- sample_to_balance (bool): Resamples the dataset to make it balanced. Defaults to False.
- subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None.
- transform (Optional[Callable]): Transform(s) for input data. Defaults to None.
- target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None.
- seed (int): Seed number. Defaults to 4711.
-
- """
- super().__init__(
- train=train,
- subsample_fraction=subsample_fraction,
- transform=transform,
- target_transform=target_transform,
- pad_token=pad_token,
- )
-
- self.sample_to_balance = sample_to_balance
-
- # Have to transpose the emnist characters, ToTensor norms input between [0,1].
- if transform is None:
- self.transform = Compose([Transpose(), ToTensor()])
-
- self.target_transform = None
-
- self.seed = seed
-
- def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
- """Fetches samples from the dataset.
-
- Args:
- index (Union[int, Tensor]): The indices of the samples to fetch.
-
- Returns:
- Tuple[Tensor, Tensor]: Data target tuple.
-
- """
- if torch.is_tensor(index):
- index = index.tolist()
-
- data = self.data[index]
- targets = self.targets[index]
-
- if self.transform:
- data = self.transform(data)
-
- if self.target_transform:
- targets = self.target_transform(targets)
-
- return data, targets
-
- def __repr__(self) -> str:
- """Returns information about the dataset."""
- return (
- "EMNIST Dataset\n"
- f"Num classes: {self.num_classes}\n"
- f"Input shape: {self.input_shape}\n"
- f"Mapping: {self.mapper.mapping}\n"
- )
-
- def _sample_to_balance(self) -> None:
- """Because the dataset is not balanced, we take at most the mean number of instances per class."""
- np.random.seed(self.seed)
- x = self._data
- y = self._targets
- num_to_sample = int(np.bincount(y.flatten()).mean())
- all_sampled_indices = []
- for label in np.unique(y.flatten()):
- inds = np.where(y == label)[0]
- sampled_indices = np.unique(np.random.choice(inds, num_to_sample))
- all_sampled_indices.append(sampled_indices)
- indices = np.concatenate(all_sampled_indices)
- x_sampled = x[indices]
- y_sampled = y[indices]
- self._data = x_sampled
- self._targets = y_sampled
-
- def load_or_generate_data(self) -> None:
- """Fetch the EMNIST dataset."""
- dataset = EMNIST(
- root=DATA_DIRNAME,
- split="byclass",
- train=self.train,
- download=False,
- transform=None,
- target_transform=None,
- )
-
- self._data = dataset.data
- self._targets = dataset.targets
-
- if self.sample_to_balance:
- self._sample_to_balance()
-
- if self.subsample_fraction is not None:
- self._subsample()
diff --git a/text_recognizer/datasets/emnist_essentials.json b/text_recognizer/datasets/emnist_essentials.json
deleted file mode 100644
index 2a0648a..0000000
--- a/text_recognizer/datasets/emnist_essentials.json
+++ /dev/null
@@ -1 +0,0 @@
-{"mapping": [[0, "0"], [1, "1"], [2, "2"], [3, "3"], [4, "4"], [5, "5"], [6, "6"], [7, "7"], [8, "8"], [9, "9"], [10, "A"], [11, "B"], [12, "C"], [13, "D"], [14, "E"], [15, "F"], [16, "G"], [17, "H"], [18, "I"], [19, "J"], [20, "K"], [21, "L"], [22, "M"], [23, "N"], [24, "O"], [25, "P"], [26, "Q"], [27, "R"], [28, "S"], [29, "T"], [30, "U"], [31, "V"], [32, "W"], [33, "X"], [34, "Y"], [35, "Z"], [36, "a"], [37, "b"], [38, "c"], [39, "d"], [40, "e"], [41, "f"], [42, "g"], [43, "h"], [44, "i"], [45, "j"], [46, "k"], [47, "l"], [48, "m"], [49, "n"], [50, "o"], [51, "p"], [52, "q"], [53, "r"], [54, "s"], [55, "t"], [56, "u"], [57, "v"], [58, "w"], [59, "x"], [60, "y"], [61, "z"]], "input_shape": [28, 28]}