author    aktersnurra <gustaf.rydholm@gmail.com>  2020-11-08 14:54:44 +0100
committer aktersnurra <gustaf.rydholm@gmail.com>  2020-11-08 14:54:44 +0100
commit    dc28cbe2b4ed77be92ee8b2b69a20689c3bf02a4 (patch)
tree      1b5fc0d06952e13727e85c4f973a26d277068453 /src/text_recognizer/datasets
parent    e181195a699d7fa237f256d90ab4dedffc03d405 (diff)
new updates
Diffstat (limited to 'src/text_recognizer/datasets')
-rw-r--r--  src/text_recognizer/datasets/__init__.py               |  4
-rw-r--r--  src/text_recognizer/datasets/dataset.py                | 22
-rw-r--r--  src/text_recognizer/datasets/emnist_dataset.py         |  6
-rw-r--r--  src/text_recognizer/datasets/emnist_lines_dataset.py   | 51
-rw-r--r--  src/text_recognizer/datasets/iam_lines_dataset.py      |  6
-rw-r--r--  src/text_recognizer/datasets/iam_paragraphs_dataset.py |  7
-rw-r--r--  src/text_recognizer/datasets/transforms.py             | 40
-rw-r--r--  src/text_recognizer/datasets/util.py                   | 31
8 files changed, 133 insertions(+), 34 deletions(-)
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py
index a3af9b1..d8372e3 100644
--- a/src/text_recognizer/datasets/__init__.py
+++ b/src/text_recognizer/datasets/__init__.py
@@ -1,5 +1,5 @@
"""Dataset modules."""
-from .emnist_dataset import EmnistDataset, Transpose
+from .emnist_dataset import EmnistDataset
from .emnist_lines_dataset import (
construct_image_from_string,
EmnistLinesDataset,
@@ -8,6 +8,7 @@ from .emnist_lines_dataset import (
from .iam_dataset import IamDataset
from .iam_lines_dataset import IamLinesDataset
from .iam_paragraphs_dataset import IamParagraphsDataset
+from .transforms import AddTokens, Transpose
from .util import (
_download_raw_dataset,
compute_sha256,
@@ -19,6 +20,7 @@ from .util import (
__all__ = [
"_download_raw_dataset",
+ "AddTokens",
"compute_sha256",
"construct_image_from_string",
"DATA_DIRNAME",
diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py
index 05520e5..2de7f09 100644
--- a/src/text_recognizer/datasets/dataset.py
+++ b/src/text_recognizer/datasets/dataset.py
@@ -18,6 +18,9 @@ class Dataset(data.Dataset):
subsample_fraction: float = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
+ init_token: Optional[str] = None,
+ pad_token: Optional[str] = None,
+ eos_token: Optional[str] = None,
) -> None:
"""Initialization of Dataset class.
@@ -26,12 +29,14 @@ class Dataset(data.Dataset):
subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None.
transform (Optional[Callable]): Transform(s) for input data. Defaults to None.
target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None.
+ init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
+ pad_token (Optional[str]): String representing the pad token. Defaults to None.
+ eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
Raises:
ValueError: If subsample_fraction is not None and outside the range (0, 1).
"""
-
self.train = train
self.split = "train" if self.train else "test"
@@ -40,19 +45,18 @@ class Dataset(data.Dataset):
raise ValueError("The subsample fraction must be in (0, 1).")
self.subsample_fraction = subsample_fraction
- self._mapper = EmnistMapper()
+ self._mapper = EmnistMapper(
+ init_token=init_token, eos_token=eos_token, pad_token=pad_token
+ )
self._input_shape = self._mapper.input_shape
self._output_shape = self._mapper._num_classes
self.num_classes = self.mapper.num_classes
# Set transforms.
- self.transform = transform
- if self.transform is None:
- self.transform = ToTensor()
-
- self.target_transform = target_transform
- if self.target_transform is None:
- self.target_transform = torch.tensor
+ self.transform = transform if transform is not None else ToTensor()
+ self.target_transform = (
+ target_transform if target_transform is not None else torch.tensor
+ )
self._data = None
self._targets = None
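
The base class now threads optional token strings through to the shared EmnistMapper and falls back to ToTensor / torch.tensor when no transforms are supplied. A minimal sketch of what this enables, assuming the EMNIST essentials file has been generated and that the base class can be instantiated directly for inspection; the token strings are hypothetical examples:

    from text_recognizer.datasets.dataset import Dataset

    # Any symbols not already in the EMNIST mapping work as tokens. They are
    # forwarded to EmnistMapper, which assigns each one its own class index.
    dataset = Dataset(
        train=True,
        pad_token="_",
        init_token="<sos>",
        eos_token="<eos>",
    )
    print(dataset.num_classes)  # now includes the extra token symbols
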
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index d01dcee..9884fdf 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -22,6 +22,7 @@ class EmnistDataset(Dataset):
def __init__(
self,
+ pad_token: str = "_",
train: bool = False,
sample_to_balance: bool = False,
subsample_fraction: float = None,
@@ -32,6 +33,7 @@ class EmnistDataset(Dataset):
"""Loads the dataset and the mappings.
Args:
+ pad_token (str): The pad token symbol. Defaults to _.
train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False.
sample_to_balance (bool): Resamples the dataset to make it balanced. Defaults to False.
subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None.
@@ -45,6 +47,7 @@ class EmnistDataset(Dataset):
subsample_fraction=subsample_fraction,
transform=transform,
target_transform=target_transform,
+ pad_token=pad_token,
)
self.sample_to_balance = sample_to_balance
@@ -53,8 +56,7 @@ class EmnistDataset(Dataset):
if transform is None:
self.transform = Compose([Transpose(), ToTensor()])
- # The EMNIST dataset is already casted to tensors.
- self.target_transform = target_transform
+ self.target_transform = None
self.seed = seed
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index 6268a01..6871492 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -4,6 +4,7 @@ from collections import defaultdict
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union
+import click
import h5py
from loguru import logger
import numpy as np
@@ -37,6 +38,9 @@ class EmnistLinesDataset(Dataset):
max_overlap: float = 0.33,
num_samples: int = 10000,
seed: int = 4711,
+ init_token: Optional[str] = None,
+ pad_token: Optional[str] = None,
+ eos_token: Optional[str] = None,
) -> None:
"""Set attributes and loads the dataset.
@@ -50,13 +54,21 @@ class EmnistLinesDataset(Dataset):
max_overlap (float): The maximum overlap between concatenated images. Defaults to 0.33.
num_samples (int): Number of samples to generate. Defaults to 10000.
seed (int): Seed number. Defaults to 4711.
+ init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
+ pad_token (Optional[str]): String representing the pad token. Defaults to None.
+ eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
"""
+ self.pad_token = "_" if pad_token is None else pad_token
+
super().__init__(
train=train,
transform=transform,
target_transform=target_transform,
subsample_fraction=subsample_fraction,
+ init_token=init_token,
+ pad_token=self.pad_token,
+ eos_token=eos_token,
)
# Extract dataset information.
@@ -118,11 +130,7 @@ class EmnistLinesDataset(Dataset):
@property
def data_filename(self) -> Path:
"""Path to the h5 file."""
- filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt"
- if self.train:
- filename = "train_" + filename
- else:
- filename = "test_" + filename
+ filename = "train.pt" if self.train else "test.pt"
return DATA_DIRNAME / filename
def load_or_generate_data(self) -> None:
@@ -138,8 +146,8 @@ class EmnistLinesDataset(Dataset):
"""Loads the dataset from the h5 file."""
logger.debug("EmnistLinesDataset loading data from HDF5...")
with h5py.File(self.data_filename, "r") as f:
- self._data = f["data"][:]
- self._targets = f["targets"][:]
+ self._data = f["data"][()]
+ self._targets = f["targets"][()]
def _generate_data(self) -> str:
"""Generates a dataset with the Brown corpus and Emnist characters."""
@@ -148,7 +156,10 @@ class EmnistLinesDataset(Dataset):
sentence_generator = SentenceGenerator(self.max_length)
# Load emnist dataset.
- emnist = EmnistDataset(train=self.train, sample_to_balance=True)
+ emnist = EmnistDataset(
+ train=self.train, sample_to_balance=True, pad_token=self.pad_token
+ )
+ emnist.load_or_generate_data()
samples_by_character = get_samples_by_character(
emnist.data.numpy(), emnist.targets.numpy(), self.mapper.mapping,
@@ -298,6 +309,18 @@ def convert_strings_to_categorical_labels(
return np.array([[mapping[c] for c in label] for label in labels])
+@click.command()
+@click.option(
+ "--max_length", type=int, default=34, help="Number of characters in a sentence."
+)
+@click.option(
+ "--min_overlap", type=float, default=0.0, help="Min overlap between characters."
+)
+@click.option(
+ "--max_overlap", type=float, default=0.33, help="Max overlap between characters."
+)
+@click.option("--num_train", type=int, default=10_000, help="Number of train examples.")
+@click.option("--num_test", type=int, default=1_000, help="Number of test examples.")
def create_datasets(
max_length: int = 34,
min_overlap: float = 0,
@@ -306,17 +329,17 @@ def create_datasets(
num_test: int = 1000,
) -> None:
"""Creates a training an validation dataset of Emnist lines."""
- emnist_train = EmnistDataset(train=True, sample_to_balance=True)
- emnist_test = EmnistDataset(train=False, sample_to_balance=True)
- datasets = [emnist_train, emnist_test]
num_samples = [num_train, num_test]
- for num, train, dataset in zip(num_samples, [True, False], datasets):
+ for num, train in zip(num_samples, [True, False]):
emnist_lines = EmnistLinesDataset(
train=train,
- emnist=dataset,
max_length=max_length,
min_overlap=min_overlap,
max_overlap=max_overlap,
num_samples=num,
)
- emnist_lines._load_or_generate_data()
+ emnist_lines.load_or_generate_data()
+
+
+if __name__ == "__main__":
+ create_datasets()
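
Two details above are easy to miss. Reading with f["data"][()] instead of f["data"][:] pulls the whole dataset into a NumPy array in one step and, unlike the slice form, also works when the stored object is a scalar dataset. A small self-contained sketch of the difference (the file path is a throwaway example):

    import h5py
    import numpy as np

    with h5py.File("/tmp/example.h5", "w") as f:
        f["data"] = np.zeros((4, 28, 28))
        f["count"] = 4  # a scalar dataset

    with h5py.File("/tmp/example.h5", "r") as f:
        data = f["data"][()]    # full ndarray, works for any shape
        count = f["count"][()]  # returns the scalar
        # f["count"][:] would raise an error on the scalar dataspace.

The click decorators also make the module runnable as a script, so with the defaults above create_datasets() regenerates train.pt and test.pt under DATA_DIRNAME.
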
diff --git a/src/text_recognizer/datasets/iam_lines_dataset.py b/src/text_recognizer/datasets/iam_lines_dataset.py
index 4a74b2b..fdd2fe6 100644
--- a/src/text_recognizer/datasets/iam_lines_dataset.py
+++ b/src/text_recognizer/datasets/iam_lines_dataset.py
@@ -32,12 +32,18 @@ class IamLinesDataset(Dataset):
subsample_fraction: float = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
+ init_token: Optional[str] = None,
+ pad_token: Optional[str] = None,
+ eos_token: Optional[str] = None,
) -> None:
super().__init__(
train=train,
subsample_fraction=subsample_fraction,
transform=transform,
target_transform=target_transform,
+ init_token=init_token,
+ pad_token=pad_token,
+ eos_token=eos_token,
)
@property
diff --git a/src/text_recognizer/datasets/iam_paragraphs_dataset.py b/src/text_recognizer/datasets/iam_paragraphs_dataset.py
index 4b34bd1..c1e8fe2 100644
--- a/src/text_recognizer/datasets/iam_paragraphs_dataset.py
+++ b/src/text_recognizer/datasets/iam_paragraphs_dataset.py
@@ -266,11 +266,16 @@ def _load_iam_paragraphs() -> None:
@click.option(
"--subsample_fraction",
type=float,
- default=0.0,
+ default=None,
help="The subsampling factor of the dataset.",
)
def main(subsample_fraction: float) -> None:
"""Load dataset and print info."""
+ logger.info("Creating train set...")
+ dataset = IamParagraphsDataset(train=True, subsample_fraction=subsample_fraction)
+ dataset.load_or_generate_data()
+ print(dataset)
+ logger.info("Creating test set...")
dataset = IamParagraphsDataset(subsample_fraction=subsample_fraction)
dataset.load_or_generate_data()
print(dataset)
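
The default moves from 0.0 to None because the base Dataset treats None as "use the full dataset" and rejects fractions outside the open interval (0, 1), so the old default failed validation before any data was loaded. The check, restated as a standalone sketch:

    subsample_fraction = None  # the new CLI default
    if subsample_fraction is not None and not 0.0 < subsample_fraction < 1.0:
        raise ValueError("The subsample fraction must be in (0, 1).")
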
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py
index 17231a8..8deac7f 100644
--- a/src/text_recognizer/datasets/transforms.py
+++ b/src/text_recognizer/datasets/transforms.py
@@ -3,6 +3,9 @@ import numpy as np
from PIL import Image
import torch
from torch import Tensor
+from torchvision.transforms import Compose, Resize, ToPILImage, ToTensor
+
+from text_recognizer.datasets.util import EmnistMapper
class Transpose:
@@ -11,3 +14,40 @@ class Transpose:
def __call__(self, image: Image) -> np.ndarray:
"""Swaps axis."""
return np.array(image).swapaxes(0, 1)
+
+
+class AddTokens:
+ """Adds start of sequence and end of sequence tokens to target tensor."""
+
+ def __init__(self, pad_token: str, eos_token: str, init_token: str = None) -> None:
+ self.init_token = init_token
+ self.pad_token = pad_token
+ self.eos_token = eos_token
+ if self.init_token is not None:
+ self.emnist_mapper = EmnistMapper(
+ init_token=self.init_token,
+ pad_token=self.pad_token,
+ eos_token=self.eos_token,
+ )
+ else:
+ self.emnist_mapper = EmnistMapper(
+ pad_token=self.pad_token, eos_token=self.eos_token,
+ )
+ self.pad_value = self.emnist_mapper(self.pad_token)
+ self.eos_value = self.emnist_mapper(self.eos_token)
+
+ def __call__(self, target: Tensor) -> Tensor:
+ """Adds a sos token to the begining and a eos token to the end of a target sequence."""
+ dtype, device = target.dtype, target.device
+
+ # Find where the padding starts.
+ pad_index = torch.nonzero(target == self.pad_value, as_tuple=False)[0].item()
+
+ target[pad_index] = self.eos_value
+
+ if self.init_token is not None:
+ self.sos_value = self.emnist_mapper(self.init_token)
+ sos = torch.tensor([self.sos_value], dtype=dtype, device=device)
+ target = torch.cat([sos, target], dim=0)
+
+ return target
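
A usage sketch for the new transform, assuming the EMNIST essentials file is present and that calling an EmnistMapper with a token string returns its class index (the same behavior AddTokens itself relies on); the token strings are hypothetical:

    import torch
    from text_recognizer.datasets.transforms import AddTokens
    from text_recognizer.datasets.util import EmnistMapper

    mapper = EmnistMapper(pad_token="_", init_token="<sos>", eos_token="<eos>")
    add_tokens = AddTokens(pad_token="_", eos_token="<eos>", init_token="<sos>")

    # A toy target: two character indices followed by padding.
    target = torch.tensor([mapper("a"), mapper("b"), mapper("_"), mapper("_")])

    target = add_tokens(target)
    # The first pad position was overwritten with <eos> and <sos> prepended:
    # [<sos>, a, b, <eos>, _] as class indices.

Note that the transform assumes at least one pad symbol is present: torch.nonzero(...)[0] raises an IndexError on a target with no padding.
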
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
index 73968a1..d2df8b5 100644
--- a/src/text_recognizer/datasets/util.py
+++ b/src/text_recognizer/datasets/util.py
@@ -4,6 +4,7 @@ import importlib
import json
import os
from pathlib import Path
+import string
from typing import Callable, Dict, List, Optional, Type, Union
from urllib.request import urlopen, urlretrieve
@@ -26,7 +27,7 @@ def save_emnist_essentials(emnsit_dataset: type = EMNIST) -> None:
mapping = [(i, str(label)) for i, label in enumerate(labels)]
essentials = {
"mapping": mapping,
- "input_shape": tuple(emnsit_dataset[0][0].shape[:]),
+ "input_shape": tuple(np.array(emnsit_dataset[0][0]).shape[:]),
}
logger.info("Saving emnist essentials...")
with open(ESSENTIALS_FILENAME, "w") as f:
@@ -43,11 +44,21 @@ def download_emnist() -> None:
class EmnistMapper:
"""Mapper between network output to Emnist character."""
- def __init__(self) -> None:
+ def __init__(
+ self,
+ pad_token: str,
+ init_token: Optional[str] = None,
+ eos_token: Optional[str] = None,
+ ) -> None:
"""Loads the emnist essentials file with the mapping and input shape."""
+ self.init_token = init_token
+ self.pad_token = pad_token
+ self.eos_token = eos_token
+
self.essentials = self._load_emnist_essentials()
# Load dataset information.
- self._mapping = self._augment_emnist_mapping(dict(self.essentials["mapping"]))
+ self._mapping = dict(self.essentials["mapping"])
+ self._augment_emnist_mapping()
self._inverse_mapping = {v: k for k, v in self.mapping.items()}
self._num_classes = len(self.mapping)
self._input_shape = self.essentials["input_shape"]
@@ -103,7 +114,7 @@ class EmnistMapper:
essentials = json.load(f)
return essentials
- def _augment_emnist_mapping(self, mapping: Dict) -> Dict:
+ def _augment_emnist_mapping(self) -> None:
"""Augment the mapping with extra symbols."""
# Extra symbols in IAM dataset
extra_symbols = [
@@ -127,14 +138,20 @@ class EmnistMapper:
]
# Padding symbol, which also acts as the blank symbol.
- extra_symbols.append("_")
+ extra_symbols.append(self.pad_token)
+
+ if self.init_token is not None:
+ extra_symbols.append(self.init_token)
+
+ if self.eos_token is not None:
+ extra_symbols.append(self.eos_token)
- max_key = max(mapping.keys())
+ max_key = max(self.mapping.keys())
extra_mapping = {}
for i, symbol in enumerate(extra_symbols):
extra_mapping[max_key + 1 + i] = symbol
- return {**mapping, **extra_mapping}
+ self._mapping = {**self.mapping, **extra_mapping}
def compute_sha256(filename: Union[Path, str]) -> str:
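
With these changes the mapper owns the token bookkeeping: pad_token is now required, and any init/eos tokens are appended after the IAM extras so that each gets the next free class index. A brief sketch, again with hypothetical token strings:

    from text_recognizer.datasets.util import EmnistMapper

    plain = EmnistMapper(pad_token="_")
    tokens = EmnistMapper(pad_token="_", init_token="<sos>", eos_token="<eos>")

    # The two extra symbols receive the two highest class indices.
    assert tokens.num_classes == plain.num_classes + 2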