author    aktersnurra <gustaf.rydholm@gmail.com>  2020-09-09 23:31:31 +0200
committer aktersnurra <gustaf.rydholm@gmail.com>  2020-09-09 23:31:31 +0200
commit    2b63fd952bdc9c7c72edd501cbcdbf3231e98f00 (patch)
tree      1c0e0898cb8b66faff9e5d410aa1f82d13542f68 /src/text_recognizer
parent    e1b504bca41a9793ed7e88ef14f2e2cbd85724f2 (diff)
Created an abstract Dataset class for common methods.
Diffstat (limited to 'src/text_recognizer')
-rw-r--r--  src/text_recognizer/datasets/__init__.py               |  16
-rw-r--r--  src/text_recognizer/datasets/dataset.py                | 124
-rw-r--r--  src/text_recognizer/datasets/emnist_dataset.py         | 228
-rw-r--r--  src/text_recognizer/datasets/emnist_lines_dataset.py   |  56
-rw-r--r--  src/text_recognizer/datasets/iam_dataset.py            |   6
-rw-r--r--  src/text_recognizer/datasets/iam_lines_dataset.py      |  68
-rw-r--r--  src/text_recognizer/datasets/iam_paragraphs_dataset.py |  70
-rw-r--r--  src/text_recognizer/datasets/sentence_generator.py     |   2
-rw-r--r--  src/text_recognizer/datasets/util.py                   | 125
-rw-r--r--  src/text_recognizer/models/base.py                     |   2
-rw-r--r--  src/text_recognizer/networks/ctc.py                    |   2
11 files changed, 338 insertions(+), 361 deletions(-)
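
For context (not part of the commit): a minimal sketch of what the new contract in dataset.py below asks of a concrete dataset. ToyDataset is hypothetical, and the sketch assumes emnist_essentials.json exists, since the base constructor builds an EmnistMapper. A subclass only needs to populate the private fields in load_or_generate_data and implement __getitem__:

    from typing import Tuple, Union

    import numpy as np
    from torch import Tensor

    from text_recognizer.datasets.dataset import Dataset


    class ToyDataset(Dataset):
        """Hypothetical subclass illustrating the abstract Dataset contract."""

        def load_or_generate_data(self) -> None:
            # Fill in the private fields; the base class exposes them through
            # the read-only `data`/`targets` properties.
            self._data = np.random.rand(100, 28, 28).astype(np.float32)
            self._targets = np.random.randint(0, 10, size=(100,))
            self._subsample()  # no-op unless subsample_fraction was set

        def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
            data = self.transform(self.data[index])               # default: ToTensor()
            targets = self.target_transform(self.targets[index])  # default: torch.tensor
            return data, targets


    ds = ToyDataset(train=True)
    ds.load_or_generate_data()
    image, target = ds[0]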
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py
index ede4541..a3af9b1 100644
--- a/src/text_recognizer/datasets/__init__.py
+++ b/src/text_recognizer/datasets/__init__.py
@@ -1,10 +1,5 @@
"""Dataset modules."""
-from .emnist_dataset import (
- DATA_DIRNAME,
- EmnistDataset,
- EmnistMapper,
- ESSENTIALS_FILENAME,
-)
+from .emnist_dataset import EmnistDataset, Transpose
from .emnist_lines_dataset import (
construct_image_from_string,
EmnistLinesDataset,
@@ -13,7 +8,14 @@ from .emnist_lines_dataset import (
from .iam_dataset import IamDataset
from .iam_lines_dataset import IamLinesDataset
from .iam_paragraphs_dataset import IamParagraphsDataset
-from .util import _download_raw_dataset, compute_sha256, download_url, Transpose
+from .util import (
+ _download_raw_dataset,
+ compute_sha256,
+ DATA_DIRNAME,
+ download_url,
+ EmnistMapper,
+ ESSENTIALS_FILENAME,
+)
__all__ = [
"_download_raw_dataset",
diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py
new file mode 100644
index 0000000..f328a0f
--- /dev/null
+++ b/src/text_recognizer/datasets/dataset.py
@@ -0,0 +1,124 @@
+"""Abstract dataset class."""
+from typing import Callable, Dict, Optional, Tuple, Union
+
+import torch
+from torch import Tensor
+from torch.utils import data
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets.util import EmnistMapper
+
+
+class Dataset(data.Dataset):
+ """Abstract class for with common methods for all datasets."""
+
+ def __init__(
+ self,
+ train: bool,
+ subsample_fraction: Optional[float] = None,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ ) -> None:
+ """Initialization of Dataset class.
+
+ Args:
+ train (bool): If True, loads the training set; otherwise the test set is loaded.
+ subsample_fraction (float): The fraction of the dataset to load. Defaults to None.
+ transform (Optional[Callable]): Transform(s) for input data. Defaults to None.
+ target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None.
+
+ Raises:
+ ValueError: If subsample_fraction is not None and outside the range (0, 1).
+
+ """
+ self.train = train
+ self.split = "train" if self.train else "test"
+
+ if subsample_fraction is not None:
+ if not 0.0 < subsample_fraction < 1.0:
+ raise ValueError("The subsample fraction must be in (0, 1).")
+ self.subsample_fraction = subsample_fraction
+
+ self._mapper = EmnistMapper()
+ self._input_shape = self._mapper.input_shape
+ self._output_shape = self._mapper.num_classes
+ self.num_classes = self.mapper.num_classes
+
+ # Set transforms.
+ self.transform = transform
+ if self.transform is None:
+ self.transform = ToTensor()
+
+ self.target_transform = target_transform
+ if self.target_transform is None:
+ self.target_transform = torch.tensor
+
+ self._data = None
+ self._targets = None
+
+ @property
+ def data(self) -> Tensor:
+ """The input data."""
+ return self._data
+
+ @property
+ def targets(self) -> Tensor:
+ """The target data."""
+ return self._targets
+
+ @property
+ def input_shape(self) -> Tuple:
+ """Input shape of the data."""
+ return self._input_shape
+
+ @property
+ def output_shape(self) -> Tuple:
+ """Output shape of the data."""
+ return self._output_shape
+
+ @property
+ def mapper(self) -> EmnistMapper:
+ """Returns the EmnistMapper."""
+ return self._mapper
+
+ @property
+ def mapping(self) -> Dict:
+ """Return EMNIST mapping from index to character."""
+ return self._mapper.mapping
+
+ @property
+ def inverse_mapping(self) -> Dict:
+ """Returns the inverse mapping from character to index."""
+ return self.mapper.inverse_mapping
+
+ def _subsample(self) -> None:
+ """Only this fraction of the data will be loaded."""
+ if self.subsample_fraction is None:
+ return
+ num_subsample = int(self.data.shape[0] * self.subsample_fraction)
+ self._data = self._data[:num_subsample]
+ self._targets = self._targets[:num_subsample]
+
+ def __len__(self) -> int:
+ """Returns the length of the dataset."""
+ return len(self.data)
+
+ def load_or_generate_data(self) -> None:
+ """Load or generate dataset data."""
+ raise NotImplementedError
+
+ def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
+ """Fetches samples from the dataset.
+
+ Args:
+ index (Union[int, torch.Tensor]): The indices of the samples to fetch.
+
+ Raises:
+ NotImplementedError: If the method is not implemented in child class.
+
+ """
+ raise NotImplementedError
+
+ def __repr__(self) -> str:
+ """Returns information about the dataset."""
+ raise NotImplementedError
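
One consequence of the new base class (see also the models/base.py hunk near the end of this diff): constructing a dataset no longer loads anything, and callers invoke load_or_generate_data() explicitly. A rough usage sketch:

    from text_recognizer.datasets import EmnistDataset

    dataset = EmnistDataset(train=True, subsample_fraction=0.1)
    dataset.load_or_generate_data()   # loading was previously done in __init__
    print(len(dataset), dataset.input_shape)
    image, target = dataset[0]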
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index 0715aae..81268fb 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -2,139 +2,26 @@
import json
from pathlib import Path
-from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+from typing import Callable, Optional, Tuple, Union
from loguru import logger
import numpy as np
from PIL import Image
import torch
from torch import Tensor
-from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import EMNIST
-from torchvision.transforms import Compose, Normalize, ToTensor
+from torchvision.transforms import Compose, ToTensor
-from text_recognizer.datasets.util import Transpose
+from text_recognizer.datasets.dataset import Dataset
+from text_recognizer.datasets.util import DATA_DIRNAME
-DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data"
-ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json"
+class Transpose:
+ """Transposes the EMNIST image to the correct orientation."""
-def save_emnist_essentials(emnsit_dataset: type = EMNIST) -> None:
- """Extract and saves EMNIST essentials."""
- labels = emnsit_dataset.classes
- labels.sort()
- mapping = [(i, str(label)) for i, label in enumerate(labels)]
- essentials = {
- "mapping": mapping,
- "input_shape": tuple(emnsit_dataset[0][0].shape[:]),
- }
- logger.info("Saving emnist essentials...")
- with open(ESSENTIALS_FILENAME, "w") as f:
- json.dump(essentials, f)
-
-
-def download_emnist() -> None:
- """Download the EMNIST dataset via the PyTorch class."""
- logger.info(f"Data directory is: {DATA_DIRNAME}")
- dataset = EMNIST(root=DATA_DIRNAME, split="byclass", download=True)
- save_emnist_essentials(dataset)
-
-
-class EmnistMapper:
- """Mapper between network output to Emnist character."""
-
- def __init__(self) -> None:
- """Loads the emnist essentials file with the mapping and input shape."""
- self.essentials = self._load_emnist_essentials()
- # Load dataset infromation.
- self._mapping = self._augment_emnist_mapping(dict(self.essentials["mapping"]))
- self._inverse_mapping = {v: k for k, v in self.mapping.items()}
- self._num_classes = len(self.mapping)
- self._input_shape = self.essentials["input_shape"]
-
- def __call__(self, token: Union[str, int, np.uint8]) -> Union[str, int]:
- """Maps the token to emnist character or character index.
-
- If the token is an integer (index), the method will return the Emnist character corresponding to that index.
- If the token is a str (Emnist character), the method will return the corresponding index for that character.
-
- Args:
- token (Union[str, int, np.uint8]): Eihter a string or index (integer).
-
- Returns:
- Union[str, int]: The mapping result.
-
- Raises:
- KeyError: If the index or string does not exist in the mapping.
-
- """
- if (isinstance(token, np.uint8) or isinstance(token, int)) and int(
- token
- ) in self.mapping:
- return self.mapping[int(token)]
- elif isinstance(token, str) and token in self._inverse_mapping:
- return self._inverse_mapping[token]
- else:
- raise KeyError(f"Token {token} does not exist in the mappings.")
-
- @property
- def mapping(self) -> Dict:
- """Returns the mapping between index and character."""
- return self._mapping
-
- @property
- def inverse_mapping(self) -> Dict:
- """Returns the mapping between character and index."""
- return self._inverse_mapping
-
- @property
- def num_classes(self) -> int:
- """Returns the number of classes in the dataset."""
- return self._num_classes
-
- @property
- def input_shape(self) -> List[int]:
- """Returns the input shape of the Emnist characters."""
- return self._input_shape
-
- def _load_emnist_essentials(self) -> Dict:
- """Load the EMNIST mapping."""
- with open(str(ESSENTIALS_FILENAME)) as f:
- essentials = json.load(f)
- return essentials
-
- def _augment_emnist_mapping(self, mapping: Dict) -> Dict:
- """Augment the mapping with extra symbols."""
- # Extra symbols in IAM dataset
- extra_symbols = [
- " ",
- "!",
- '"',
- "#",
- "&",
- "'",
- "(",
- ")",
- "*",
- "+",
- ",",
- "-",
- ".",
- "/",
- ":",
- ";",
- "?",
- ]
-
- # padding symbol
- extra_symbols.append("_")
-
- max_key = max(mapping.keys())
- extra_mapping = {}
- for i, symbol in enumerate(extra_symbols):
- extra_mapping[max_key + 1 + i] = symbol
-
- return {**mapping, **extra_mapping}
+ def __call__(self, image: Image) -> np.ndarray:
+ """Swaps axis."""
+ return np.array(image).swapaxes(0, 1)
class EmnistDataset(Dataset):
@@ -159,70 +46,33 @@ class EmnistDataset(Dataset):
target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None.
seed (int): Seed number. Defaults to 4711.
- Raises:
- ValueError: If subsample_fraction is not None and outside the range (0, 1).
-
"""
+ super().__init__(
+ train=train,
+ subsample_fraction=subsample_fraction,
+ transform=transform,
+ target_transform=target_transform,
+ )
- self.train = train
self.sample_to_balance = sample_to_balance
- if subsample_fraction is not None:
- if not 0.0 < subsample_fraction < 1.0:
- raise ValueError("The subsample fraction must be in (0, 1).")
- self.subsample_fraction = subsample_fraction
-
- self.transform = transform
- if self.transform is None:
+ # EMNIST images must be transposed; ToTensor also normalizes inputs to [0, 1].
+ if transform is None:
self.transform = Compose([Transpose(), ToTensor()])
+ # The EMNIST dataset is already cast to tensors.
self.target_transform = target_transform
- self.seed = seed
-
- self._mapper = EmnistMapper()
- self._input_shape = self._mapper.input_shape
- self.num_classes = self._mapper.num_classes
-
- # Load dataset.
- self._data, self._targets = self.load_emnist_dataset()
-
- @property
- def data(self) -> Tensor:
- """The input data."""
- return self._data
- @property
- def targets(self) -> Tensor:
- """The target data."""
- return self._targets
-
- @property
- def input_shape(self) -> Tuple:
- """Input shape of the data."""
- return self._input_shape
-
- @property
- def mapper(self) -> EmnistMapper:
- """Returns the EmnistMapper."""
- return self._mapper
-
- @property
- def inverse_mapping(self) -> Dict:
- """Returns the inverse mapping from character to index."""
- return self.mapper.inverse_mapping
-
- def __len__(self) -> int:
- """Returns the length of the dataset."""
- return len(self.data)
+ self.seed = seed
def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
"""Fetches samples from the dataset.
Args:
- index (Union[int, torch.Tensor]): The indices of the samples to fetch.
+ index (Union[int, Tensor]): The indices of the samples to fetch.
Returns:
- Tuple[torch.Tensor, torch.Tensor]: Data target tuple.
+ Tuple[Tensor, Tensor]: Data target tuple.
"""
if torch.is_tensor(index):
@@ -248,13 +98,11 @@ class EmnistDataset(Dataset):
f"Mapping: {self.mapper.mapping}\n"
)
- def _sample_to_balance(
- self, data: Tensor, targets: Tensor
- ) -> Tuple[np.ndarray, np.ndarray]:
+ def _sample_to_balance(self) -> None:
"""Because the dataset is not balanced, we take at most the mean number of instances per class."""
np.random.seed(self.seed)
- x = data
- y = targets
+ x = self._data
+ y = self._targets
num_to_sample = int(np.bincount(y.flatten()).mean())
all_sampled_indices = []
for label in np.unique(y.flatten()):
@@ -264,22 +112,10 @@ class EmnistDataset(Dataset):
indices = np.concatenate(all_sampled_indices)
x_sampled = x[indices]
y_sampled = y[indices]
- data = x_sampled
- targets = y_sampled
- return data, targets
-
- def _subsample(self, data: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
- """Subsamples the dataset to the specified fraction."""
- x = data
- y = targets
- num_samples = int(x.shape[0] * self.subsample_fraction)
- x_sampled = x[:num_samples]
- y_sampled = y[:num_samples]
- self.data = x_sampled
- self.targets = y_sampled
- return data, targets
+ self._data = x_sampled
+ self._targets = y_sampled
- def load_emnist_dataset(self) -> Tuple[Tensor, Tensor]:
+ def load_or_generate_data(self) -> None:
"""Fetch the EMNIST dataset."""
dataset = EMNIST(
root=DATA_DIRNAME,
@@ -290,13 +126,11 @@ class EmnistDataset(Dataset):
target_transform=None,
)
- data = dataset.data
- targets = dataset.targets
+ self._data = dataset.data
+ self._targets = dataset.targets
if self.sample_to_balance:
- data, targets = self._sample_to_balance(data, targets)
+ self._sample_to_balance()
if self.subsample_fraction is not None:
- data, targets = self._subsample(data, targets)
-
- return data, targets
+ self._subsample()
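
Transpose exists because raw EMNIST images are stored transposed relative to the natural reading orientation; composing it before ToTensor (which also scales pixel values to [0, 1]) yields upright characters. An illustrative sketch, assuming the EMNIST data can be downloaded locally:

    from torchvision.datasets import EMNIST
    from torchvision.transforms import Compose, ToTensor

    from text_recognizer.datasets import Transpose

    transform = Compose([Transpose(), ToTensor()])  # swap (H, W) axes, then scale to [0, 1]
    dataset = EMNIST(
        root="data", split="byclass", train=True, download=True, transform=transform
    )
    image, label = dataset[0]  # image: upright (1, 28, 28) float tensor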
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index 656131a..8fa77cd 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -9,17 +9,16 @@ from loguru import logger
import numpy as np
import torch
from torch import Tensor
-from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
-from text_recognizer.datasets import (
+from text_recognizer.datasets.dataset import Dataset
+from text_recognizer.datasets.emnist_dataset import EmnistDataset, Transpose
+from text_recognizer.datasets.sentence_generator import SentenceGenerator
+from text_recognizer.datasets.util import (
DATA_DIRNAME,
- EmnistDataset,
EmnistMapper,
ESSENTIALS_FILENAME,
)
-from text_recognizer.datasets.sentence_generator import SentenceGenerator
-from text_recognizer.datasets.util import Transpose
from text_recognizer.networks import sliding_window
DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines"
@@ -52,18 +51,11 @@ class EmnistLinesDataset(Dataset):
seed (int): Seed number. Defaults to 4711.
"""
- self.train = train
-
- self.transform = transform
- if self.transform is None:
- self.transform = ToTensor()
-
- self.target_transform = target_transform
- if self.target_transform is None:
- self.target_transform = torch.tensor
+ super().__init__(
+ train=train, transform=transform, target_transform=target_transform,
+ )
# Extract dataset information.
- self._mapper = EmnistMapper()
self._input_shape = self._mapper.input_shape
self.num_classes = self._mapper.num_classes
@@ -75,24 +67,12 @@ class EmnistLinesDataset(Dataset):
self.input_shape[0],
self.input_shape[1] * self.max_length,
)
- self.output_shape = (self.max_length, self.num_classes)
+ self._output_shape = (self.max_length, self.num_classes)
self.seed = seed
# Placeholders for the dataset.
- self.data = None
- self.target = None
-
- # Load dataset.
- self._load_or_generate_data()
-
- @property
- def input_shape(self) -> Tuple:
- """Input shape of the data."""
- return self._input_shape
-
- def __len__(self) -> int:
- """Returns the length of the dataset."""
- return len(self.data)
+ self._data = None
+ self._target = None
def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
"""Fetches data, target pair of the dataset for a given and index or indices.
@@ -132,16 +112,6 @@ class EmnistLinesDataset(Dataset):
)
@property
- def mapper(self) -> EmnistMapper:
- """Returns the EmnistMapper."""
- return self._mapper
-
- @property
- def mapping(self) -> Dict:
- """Return EMNIST mapping from index to character."""
- return self._mapper.mapping
-
- @property
def data_filename(self) -> Path:
"""Path to the h5 file."""
filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt"
@@ -151,7 +121,7 @@ class EmnistLinesDataset(Dataset):
filename = "test_" + filename
return DATA_DIRNAME / filename
- def _load_or_generate_data(self) -> None:
+ def load_or_generate_data(self) -> None:
"""Loads the dataset, if it does not exist a new dataset is generated before loading it."""
np.random.seed(self.seed)
@@ -163,8 +133,8 @@ class EmnistLinesDataset(Dataset):
"""Loads the dataset from the h5 file."""
logger.debug("EmnistLinesDataset loading data from HDF5...")
with h5py.File(self.data_filename, "r") as f:
- self.data = f["data"][:]
- self.targets = f["targets"][:]
+ self._data = f["data"][:]
+ self._targets = f["targets"][:]
def _generate_data(self) -> str:
"""Generates a dataset with the Brown corpus and Emnist characters."""
diff --git a/src/text_recognizer/datasets/iam_dataset.py b/src/text_recognizer/datasets/iam_dataset.py
index 5e47350..f4a869d 100644
--- a/src/text_recognizer/datasets/iam_dataset.py
+++ b/src/text_recognizer/datasets/iam_dataset.py
@@ -7,10 +7,8 @@ from boltons.cacheutils import cachedproperty
import defusedxml.ElementTree as ET
from loguru import logger
import toml
-from torch.utils.data import Dataset
-from text_recognizer.datasets import DATA_DIRNAME
-from text_recognizer.datasets.util import _download_raw_dataset
+from text_recognizer.datasets.util import _download_raw_dataset, DATA_DIRNAME
RAW_DATA_DIRNAME = DATA_DIRNAME / "raw" / "iam"
METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
@@ -20,7 +18,7 @@ DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be.
LINE_REGION_PADDING = 0 # Add this many pixels around the exact coordinates.
-class IamDataset(Dataset):
+class IamDataset:
"""IAM dataset.
"The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text,
diff --git a/src/text_recognizer/datasets/iam_lines_dataset.py b/src/text_recognizer/datasets/iam_lines_dataset.py
index 477f500..4a74b2b 100644
--- a/src/text_recognizer/datasets/iam_lines_dataset.py
+++ b/src/text_recognizer/datasets/iam_lines_dataset.py
@@ -5,11 +5,15 @@ import h5py
from loguru import logger
import torch
from torch import Tensor
-from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
-from text_recognizer.datasets.emnist_dataset import DATA_DIRNAME, EmnistMapper
-from text_recognizer.datasets.util import compute_sha256, download_url
+from text_recognizer.datasets.dataset import Dataset
+from text_recognizer.datasets.util import (
+ compute_sha256,
+ DATA_DIRNAME,
+ download_url,
+ EmnistMapper,
+)
PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_lines"
@@ -29,47 +33,26 @@ class IamLinesDataset(Dataset):
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
- self.train = train
- self.split = "train" if self.train else "test"
- self._mapper = EmnistMapper()
- self.num_classes = self.mapper.num_classes
-
- # Set transforms.
- self.transform = transform
- if self.transform is None:
- self.transform = ToTensor()
-
- self.target_transform = target_transform
- if self.target_transform is None:
- self.target_transform = torch.tensor
-
- self.subsample_fraction = subsample_fraction
- self.data = None
- self.targets = None
-
- @property
- def mapper(self) -> EmnistMapper:
- """Returns the EmnistMapper."""
- return self._mapper
-
- @property
- def mapping(self) -> Dict:
- """Return EMNIST mapping from index to character."""
- return self._mapper.mapping
+ super().__init__(
+ train=train,
+ subsample_fraction=subsample_fraction,
+ transform=transform,
+ target_transform=target_transform,
+ )
@property
def input_shape(self) -> Tuple:
"""Input shape of the data."""
- return self.data.shape[1:]
+ return self.data.shape[1:] if self.data is not None else None
@property
def output_shape(self) -> Tuple:
"""Output shape of the data."""
- return self.targets.shape[1:] + (self.num_classes,)
-
- def __len__(self) -> int:
- """Returns the length of the dataset."""
- return len(self.data)
+ return (
+ self.targets.shape[1:] + (self.num_classes,)
+ if self.targets is not None
+ else None
+ )
def load_or_generate_data(self) -> None:
"""Load or generate dataset data."""
@@ -78,19 +61,10 @@ class IamLinesDataset(Dataset):
logger.info("Downloading IAM lines...")
download_url(PROCESSED_DATA_URL, PROCESSED_DATA_FILENAME)
with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
- self.data = f[f"x_{self.split}"][:]
- self.targets = f[f"y_{self.split}"][:]
+ self._data = f[f"x_{self.split}"][:]
+ self._targets = f[f"y_{self.split}"][:]
self._subsample()
- def _subsample(self) -> None:
- """Only a fraction of the data will be loaded."""
- if self.subsample_fraction is None:
- return
-
- num_samples = int(self.data.shape[0] * self.subsample_fraction)
- self.data = self.data[:num_samples]
- self.targets = self.targets[:num_samples]
-
def __repr__(self) -> str:
"""Print info about the dataset."""
return (
diff --git a/src/text_recognizer/datasets/iam_paragraphs_dataset.py b/src/text_recognizer/datasets/iam_paragraphs_dataset.py
index d65b346..4b34bd1 100644
--- a/src/text_recognizer/datasets/iam_paragraphs_dataset.py
+++ b/src/text_recognizer/datasets/iam_paragraphs_dataset.py
@@ -8,13 +8,17 @@ from loguru import logger
import numpy as np
import torch
from torch import Tensor
-from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from text_recognizer import util
-from text_recognizer.datasets.emnist_dataset import DATA_DIRNAME, EmnistMapper
+from text_recognizer.datasets.dataset import Dataset
from text_recognizer.datasets.iam_dataset import IamDataset
-from text_recognizer.datasets.util import compute_sha256, download_url
+from text_recognizer.datasets.util import (
+ compute_sha256,
+ DATA_DIRNAME,
+ download_url,
+ EmnistMapper,
+)
INTERIM_DATA_DIRNAME = DATA_DIRNAME / "interim" / "iam_paragraphs"
DEBUG_CROPS_DIRNAME = INTERIM_DATA_DIRNAME / "debug_crops"
@@ -28,11 +32,7 @@ SEED = 4711
class IamParagraphsDataset(Dataset):
- """IAM Paragraphs dataset for paragraphs of handwritten text.
-
- TODO: __getitem__, __len__, get_data_target_from_id
-
- """
+ """IAM Paragraphs dataset for paragraphs of handwritten text."""
def __init__(
self,
@@ -41,34 +41,20 @@ class IamParagraphsDataset(Dataset):
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
-
+ super().__init__(
+ train=train,
+ subsample_fraction=subsample_fraction,
+ transform=transform,
+ target_transform=target_transform,
+ )
# Load Iam dataset.
self.iam_dataset = IamDataset()
- self.train = train
- self.split = "train" if self.train else "test"
self.num_classes = 3
self._input_shape = (256, 256)
self._output_shape = self._input_shape + (self.num_classes,)
- self.subsample_fraction = subsample_fraction
-
- # Set transforms.
- self.transform = transform
- if self.transform is None:
- self.transform = ToTensor()
-
- self.target_transform = target_transform
- if self.target_transform is None:
- self.target_transform = torch.tensor
-
- self._data = None
- self._targets = None
self._ids = None
- def __len__(self) -> int:
- """Returns the length of the dataset."""
- return len(self.data)
-
def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]:
"""Fetches data, target pair of the dataset for a given and index or indices.
@@ -94,26 +80,6 @@ class IamParagraphsDataset(Dataset):
return data, targets
@property
- def input_shape(self) -> Tuple:
- """Input shape of the data."""
- return self._input_shape
-
- @property
- def output_shape(self) -> Tuple:
- """Output shape of the data."""
- return self._output_shape
-
- @property
- def data(self) -> Tensor:
- """The input data."""
- return self._data
-
- @property
- def targets(self) -> Tensor:
- """The target data."""
- return self._targets
-
- @property
def ids(self) -> Tensor:
"""Ids of the dataset."""
return self._ids
@@ -201,14 +167,6 @@ class IamParagraphsDataset(Dataset):
logger.info(f"Setting them to {max_crop_width}x{max_crop_width}")
return crop_dims
- def _subsample(self) -> None:
- """Only this fraction of the data will be loaded."""
- if self.subsample_fraction is None:
- return
- num_subsample = int(self.data.shape[0] * self.subsample_fraction)
- self.data = self.data[:num_subsample]
- self.targets = self.targets[:num_subsample]
-
def __repr__(self) -> str:
"""Return info about the dataset."""
return (
diff --git a/src/text_recognizer/datasets/sentence_generator.py b/src/text_recognizer/datasets/sentence_generator.py
index ee86bd4..dd76652 100644
--- a/src/text_recognizer/datasets/sentence_generator.py
+++ b/src/text_recognizer/datasets/sentence_generator.py
@@ -9,7 +9,7 @@ import nltk
from nltk.corpus.reader.util import ConcatenatedCorpusView
import numpy as np
-from text_recognizer.datasets import DATA_DIRNAME
+from text_recognizer.datasets.util import DATA_DIRNAME
NLTK_DATA_DIRNAME = DATA_DIRNAME / "raw" / "nltk"
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
index dd16bed..3acf5db 100644
--- a/src/text_recognizer/datasets/util.py
+++ b/src/text_recognizer/datasets/util.py
@@ -1,6 +1,7 @@
"""Util functions for datasets."""
import hashlib
import importlib
+import json
import os
from pathlib import Path
from typing import Callable, Dict, List, Optional, Type, Union
@@ -11,15 +12,129 @@ from loguru import logger
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader, Dataset
+from torchvision.datasets import EMNIST
from tqdm import tqdm
+DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data"
+ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json"
-class Transpose:
- """Transposes the EMNIST image to the correct orientation."""
- def __call__(self, image: Image) -> np.ndarray:
- """Swaps axis."""
- return np.array(image).swapaxes(0, 1)
+def save_emnist_essentials(emnist_dataset: EMNIST) -> None:
+ """Extract and save EMNIST essentials (the class mapping and input shape)."""
+ labels = emnist_dataset.classes
+ labels.sort()
+ mapping = [(i, str(label)) for i, label in enumerate(labels)]
+ essentials = {
+ "mapping": mapping,
+ "input_shape": tuple(emnist_dataset[0][0].shape[:]),
+ }
+ logger.info("Saving emnist essentials...")
+ with open(ESSENTIALS_FILENAME, "w") as f:
+ json.dump(essentials, f)
+
+
+def download_emnist() -> None:
+ """Download the EMNIST dataset via the PyTorch class."""
+ logger.info(f"Data directory is: {DATA_DIRNAME}")
+ dataset = EMNIST(root=DATA_DIRNAME, split="byclass", download=True)
+ save_emnist_essentials(dataset)
+
+
+class EmnistMapper:
+ """Mapper between network output to Emnist character."""
+
+ def __init__(self) -> None:
+ """Loads the emnist essentials file with the mapping and input shape."""
+ self.essentials = self._load_emnist_essentials()
+ # Load dataset information.
+ self._mapping = self._augment_emnist_mapping(dict(self.essentials["mapping"]))
+ self._inverse_mapping = {v: k for k, v in self.mapping.items()}
+ self._num_classes = len(self.mapping)
+ self._input_shape = self.essentials["input_shape"]
+
+ def __call__(self, token: Union[str, int, np.uint8]) -> Union[str, int]:
+ """Maps the token to emnist character or character index.
+
+ If the token is an integer (index), the method will return the Emnist character corresponding to that index.
+ If the token is a str (Emnist character), the method will return the corresponding index for that character.
+
+ Args:
+ token (Union[str, int, np.uint8]): Either a string or an index (integer).
+
+ Returns:
+ Union[str, int]: The mapping result.
+
+ Raises:
+ KeyError: If the index or string does not exist in the mapping.
+
+ """
+ if isinstance(token, (int, np.uint8)) and int(token) in self.mapping:
+ return self.mapping[int(token)]
+ elif isinstance(token, str) and token in self._inverse_mapping:
+ return self._inverse_mapping[token]
+ else:
+ raise KeyError(f"Token {token} does not exist in the mappings.")
+
+ @property
+ def mapping(self) -> Dict:
+ """Returns the mapping between index and character."""
+ return self._mapping
+
+ @property
+ def inverse_mapping(self) -> Dict:
+ """Returns the mapping between character and index."""
+ return self._inverse_mapping
+
+ @property
+ def num_classes(self) -> int:
+ """Returns the number of classes in the dataset."""
+ return self._num_classes
+
+ @property
+ def input_shape(self) -> List[int]:
+ """Returns the input shape of the Emnist characters."""
+ return self._input_shape
+
+ def _load_emnist_essentials(self) -> Dict:
+ """Load the EMNIST mapping."""
+ with open(str(ESSENTIALS_FILENAME)) as f:
+ essentials = json.load(f)
+ return essentials
+
+ def _augment_emnist_mapping(self, mapping: Dict) -> Dict:
+ """Augment the mapping with extra symbols."""
+ # Extra symbols in IAM dataset
+ extra_symbols = [
+ " ",
+ "!",
+ '"',
+ "#",
+ "&",
+ "'",
+ "(",
+ ")",
+ "*",
+ "+",
+ ",",
+ "-",
+ ".",
+ "/",
+ ":",
+ ";",
+ "?",
+ ]
+
+ # padding symbol
+ extra_symbols.append("_")
+
+ max_key = max(mapping.keys())
+ extra_mapping = {}
+ for i, symbol in enumerate(extra_symbols):
+ extra_mapping[max_key + 1 + i] = symbol
+
+ return {**mapping, **extra_mapping}
def compute_sha256(filename: Union[Path, str]) -> str:
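
With EmnistMapper now living in util.py, callers import it from there (as the ctc.py hunk below does). The mapper is bidirectional; a quick illustrative round trip (the concrete characters come from emnist_essentials.json, so the values here are made up):

    from text_recognizer.datasets.util import EmnistMapper

    mapper = EmnistMapper()   # loads emnist_essentials.json
    char = mapper(36)         # index -> character (e.g. "a"; illustrative)
    index = mapper(char)      # character -> index, back to 36
    pad = mapper("_")         # "_" is the padding symbol added by _augment_emnist_mapping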
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index 153e19a..d23fe56 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -140,6 +140,7 @@ class Model(ABC):
if not self.data_prepared:
# Load train dataset.
train_dataset = self.dataset(train=True, **self.dataset_args["args"])
+ train_dataset.load_or_generate_data()
# Set input shape.
self._input_shape = train_dataset.input_shape
@@ -156,6 +157,7 @@ class Model(ABC):
# Load test dataset.
self.test_dataset = self.dataset(train=False, **self.dataset_args["args"])
+ self.test_dataset.load_or_generate_data()
# Set the flag to true to disable the ability to load data again.
self.data_prepared = True
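
Ordering matters in this hunk: a dataset's input_shape is only meaningful after load_or_generate_data() has populated it. A condensed paraphrase of the flow the two added calls produce (not the full method):

    # Paraphrased from Model._prepare_data in this diff:
    train_dataset = self.dataset(train=True, **self.dataset_args["args"])
    train_dataset.load_or_generate_data()            # populate _data/_targets first
    self._input_shape = train_dataset.input_shape    # safe to read only after loading

    self.test_dataset = self.dataset(train=False, **self.dataset_args["args"])
    self.test_dataset.load_or_generate_data()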
diff --git a/src/text_recognizer/networks/ctc.py b/src/text_recognizer/networks/ctc.py
index fc0d21d..72f18b8 100644
--- a/src/text_recognizer/networks/ctc.py
+++ b/src/text_recognizer/networks/ctc.py
@@ -5,7 +5,7 @@ from einops import rearrange
import torch
from torch import Tensor
-from text_recognizer.datasets import EmnistMapper
+from text_recognizer.datasets.util import EmnistMapper
def greedy_decoder(