summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-11-08 12:41:04 +0100
committeraktersnurra <gustaf.rydholm@gmail.com>2020-11-08 12:41:04 +0100
commitbeeaef529e7c893a3475fe27edc880e283373725 (patch)
tree59eb72562bf7a5a9470c2586e6280600ad94f1ae /src/text_recognizer/datasets
parent4d7713746eb936832e84852e90292936b933e87d (diff)
Trying to get the CNNTransformer to work, but it is hard.
Diffstat (limited to 'src/text_recognizer/datasets')
-rw-r--r--src/text_recognizer/datasets/emnist_dataset.py5
-rw-r--r--src/text_recognizer/datasets/emnist_lines_dataset.py35
-rw-r--r--src/text_recognizer/datasets/transforms.py27
3 files changed, 48 insertions, 19 deletions
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index a8901d6..9884fdf 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -22,6 +22,7 @@ class EmnistDataset(Dataset):
def __init__(
self,
+ pad_token: str = None,
train: bool = False,
sample_to_balance: bool = False,
subsample_fraction: float = None,
@@ -32,6 +33,7 @@ class EmnistDataset(Dataset):
"""Loads the dataset and the mappings.
Args:
+ pad_token (str): The pad token symbol. Defaults to _.
train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False.
sample_to_balance (bool): Resamples the dataset to make it balanced. Defaults to False.
subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None.
@@ -45,6 +47,7 @@ class EmnistDataset(Dataset):
subsample_fraction=subsample_fraction,
transform=transform,
target_transform=target_transform,
+ pad_token=pad_token,
)
self.sample_to_balance = sample_to_balance
@@ -53,6 +56,8 @@ class EmnistDataset(Dataset):
if transform is None:
self.transform = Compose([Transpose(), ToTensor()])
+ self.target_transform = None
+
self.seed = seed
def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index 6091da8..6871492 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -4,6 +4,7 @@ from collections import defaultdict
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union
+import click
import h5py
from loguru import logger
import numpy as np
@@ -58,13 +59,15 @@ class EmnistLinesDataset(Dataset):
eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
"""
+ self.pad_token = "_" if pad_token is None else pad_token
+
super().__init__(
train=train,
transform=transform,
target_transform=target_transform,
subsample_fraction=subsample_fraction,
init_token=init_token,
- pad_token=pad_token,
+ pad_token=self.pad_token,
eos_token=eos_token,
)
@@ -127,11 +130,7 @@ class EmnistLinesDataset(Dataset):
@property
def data_filename(self) -> Path:
"""Path to the h5 file."""
- filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt"
- if self.train:
- filename = "train_" + filename
- else:
- filename = "test_" + filename
+ filename = "train.pt" if self.train else "test.pt"
return DATA_DIRNAME / filename
def load_or_generate_data(self) -> None:
@@ -147,8 +146,8 @@ class EmnistLinesDataset(Dataset):
"""Loads the dataset from the h5 file."""
logger.debug("EmnistLinesDataset loading data from HDF5...")
with h5py.File(self.data_filename, "r") as f:
- self._data = f["data"][:]
- self._targets = f["targets"][:]
+ self._data = f["data"][()]
+ self._targets = f["targets"][()]
def _generate_data(self) -> str:
"""Generates a dataset with the Brown corpus and Emnist characters."""
@@ -157,7 +156,9 @@ class EmnistLinesDataset(Dataset):
sentence_generator = SentenceGenerator(self.max_length)
# Load emnist dataset.
- emnist = EmnistDataset(train=self.train, sample_to_balance=True)
+ emnist = EmnistDataset(
+ train=self.train, sample_to_balance=True, pad_token=self.pad_token
+ )
emnist.load_or_generate_data()
samples_by_character = get_samples_by_character(
@@ -308,6 +309,18 @@ def convert_strings_to_categorical_labels(
return np.array([[mapping[c] for c in label] for label in labels])
+@click.command()
+@click.option(
+ "--max_length", type=int, default=34, help="Number of characters in a sentence."
+)
+@click.option(
+ "--min_overlap", type=float, default=0.0, help="Min overlap between characters."
+)
+@click.option(
+ "--max_overlap", type=float, default=0.33, help="Max overlap between characters."
+)
+@click.option("--num_train", type=int, default=10_000, help="Number of train examples.")
+@click.option("--num_test", type=int, default=1_000, help="Number of test examples.")
def create_datasets(
max_length: int = 34,
min_overlap: float = 0,
@@ -326,3 +339,7 @@ def create_datasets(
num_samples=num,
)
emnist_lines.load_or_generate_data()
+
+
+if __name__ == "__main__":
+ create_datasets()
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py
index c058972..8deac7f 100644
--- a/src/text_recognizer/datasets/transforms.py
+++ b/src/text_recognizer/datasets/transforms.py
@@ -3,7 +3,7 @@ import numpy as np
from PIL import Image
import torch
from torch import Tensor
-from torchvision.transforms import Compose, ToTensor
+from torchvision.transforms import Compose, Resize, ToPILImage, ToTensor
from text_recognizer.datasets.util import EmnistMapper
@@ -19,28 +19,35 @@ class Transpose:
class AddTokens:
"""Adds start of sequence and end of sequence tokens to target tensor."""
- def __init__(self, init_token: str, pad_token: str, eos_token: str,) -> None:
+ def __init__(self, pad_token: str, eos_token: str, init_token: str = None) -> None:
self.init_token = init_token
self.pad_token = pad_token
self.eos_token = eos_token
- self.emnist_mapper = EmnistMapper(
- init_token=self.init_token,
- pad_token=self.pad_token,
- eos_token=self.eos_token,
- )
+ if self.init_token is not None:
+ self.emnist_mapper = EmnistMapper(
+ init_token=self.init_token,
+ pad_token=self.pad_token,
+ eos_token=self.eos_token,
+ )
+ else:
+ self.emnist_mapper = EmnistMapper(
+ pad_token=self.pad_token, eos_token=self.eos_token,
+ )
self.pad_value = self.emnist_mapper(self.pad_token)
- self.sos_value = self.emnist_mapper(self.init_token)
self.eos_value = self.emnist_mapper(self.eos_token)
def __call__(self, target: Tensor) -> Tensor:
"""Adds a sos token to the begining and a eos token to the end of a target sequence."""
dtype, device = target.dtype, target.device
- sos = torch.tensor([self.sos_value], dtype=dtype, device=device)
# Find the where padding starts.
pad_index = torch.nonzero(target == self.pad_value, as_tuple=False)[0].item()
target[pad_index] = self.eos_value
- target = torch.cat([sos, target], dim=0)
+ if self.init_token is not None:
+ self.sos_value = self.emnist_mapper(self.init_token)
+ sos = torch.tensor([self.sos_value], dtype=dtype, device=device)
+ target = torch.cat([sos, target], dim=0)
+
return target