summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets/emnist_lines_dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/datasets/emnist_lines_dataset.py')
-rw-r--r--src/text_recognizer/datasets/emnist_lines_dataset.py38
1 files changed, 22 insertions, 16 deletions
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index b0617f5..656131a 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -9,8 +9,8 @@ from loguru import logger
import numpy as np
import torch
from torch import Tensor
-from torch.utils.data import DataLoader, Dataset
-from torchvision.transforms import Compose, Normalize, ToTensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
from text_recognizer.datasets import (
DATA_DIRNAME,
@@ -20,6 +20,7 @@ from text_recognizer.datasets import (
)
from text_recognizer.datasets.sentence_generator import SentenceGenerator
from text_recognizer.datasets.util import Transpose
+from text_recognizer.networks import sliding_window
DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines"
@@ -55,7 +56,7 @@ class EmnistLinesDataset(Dataset):
self.transform = transform
if self.transform is None:
- self.transform = Compose([ToTensor()])
+ self.transform = ToTensor()
self.target_transform = target_transform
if self.target_transform is None:
@@ -63,14 +64,14 @@ class EmnistLinesDataset(Dataset):
# Extract dataset information.
self._mapper = EmnistMapper()
- self.input_shape = self._mapper.input_shape
+ self._input_shape = self._mapper.input_shape
self.num_classes = self._mapper.num_classes
self.max_length = max_length
self.min_overlap = min_overlap
self.max_overlap = max_overlap
self.num_samples = num_samples
- self.input_shape = (
+ self._input_shape = (
self.input_shape[0],
self.input_shape[1] * self.max_length,
)
@@ -84,6 +85,11 @@ class EmnistLinesDataset(Dataset):
# Load dataset.
self._load_or_generate_data()
+ @property
+ def input_shape(self) -> Tuple:
+ """Input shape of the data."""
+ return self._input_shape
+
def __len__(self) -> int:
"""Returns the length of the dataset."""
return len(self.data)
@@ -112,11 +118,6 @@ class EmnistLinesDataset(Dataset):
return data, targets
- @property
- def __name__(self) -> str:
- """Returns the name of the dataset."""
- return "EmnistLinesDataset"
-
def __repr__(self) -> str:
"""Returns information about the dataset."""
return (
@@ -136,13 +137,18 @@ class EmnistLinesDataset(Dataset):
return self._mapper
@property
+ def mapping(self) -> Dict:
+ """Return EMNIST mapping from index to character."""
+ return self._mapper.mapping
+
+ @property
def data_filename(self) -> Path:
"""Path to the h5 file."""
filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt"
if self.train:
filename = "train_" + filename
else:
- filename = "val_" + filename
+ filename = "test_" + filename
return DATA_DIRNAME / filename
def _load_or_generate_data(self) -> None:
@@ -184,7 +190,7 @@ class EmnistLinesDataset(Dataset):
)
targets = convert_strings_to_categorical_labels(
- targets, self.emnist.inverse_mapping
+ targets, emnist.inverse_mapping
)
f.create_dataset("data", data=data, dtype="u1", compression="lzf")
@@ -322,13 +328,13 @@ def create_datasets(
min_overlap: float = 0,
max_overlap: float = 0.33,
num_train: int = 10000,
- num_val: int = 1000,
+ num_test: int = 1000,
) -> None:
"""Creates a training an validation dataset of Emnist lines."""
emnist_train = EmnistDataset(train=True, sample_to_balance=True)
- emnist_val = EmnistDataset(train=False, sample_to_balance=True)
- datasets = [emnist_train, emnist_val]
- num_samples = [num_train, num_val]
+ emnist_test = EmnistDataset(train=False, sample_to_balance=True)
+ datasets = [emnist_train, emnist_test]
+ num_samples = [num_train, num_test]
for num, train, dataset in zip(num_samples, [True, False], datasets):
emnist_lines = EmnistLinesDataset(
train=train,