summary | refs | log | tree | commit | diff
path: root/src/text_recognizer/datasets
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/datasets')
-rw-r--r-- src/text_recognizer/datasets/emnist_dataset.py 43
-rw-r--r-- src/text_recognizer/datasets/emnist_lines_dataset.py 9
2 files changed, 26 insertions, 26 deletions
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index 96f84e5..49ebad3 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -8,6 +8,7 @@ from loguru import logger
import numpy as np
from PIL import Image
import torch
+from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import EMNIST
from torchvision.transforms import Compose, Normalize, ToTensor
@@ -183,12 +184,8 @@ class EmnistDataset(Dataset):
self.input_shape = self._mapper.input_shape
self.num_classes = self._mapper.num_classes
- # Placeholders
- self.data = None
- self.targets = None
-
# Load dataset.
- self.load_emnist_dataset()
+ self.data, self.targets = self.load_emnist_dataset()
@property
def mapper(self) -> EmnistMapper:
@@ -199,9 +196,7 @@ class EmnistDataset(Dataset):
"""Returns the length of the dataset."""
return len(self.data)
- def __getitem__(
- self, index: Union[int, torch.Tensor]
- ) -> Tuple[torch.Tensor, torch.Tensor]:
+ def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
"""Fetches samples from the dataset.
Args:
@@ -239,11 +234,13 @@ class EmnistDataset(Dataset):
f"Mapping: {self.mapper.mapping}\n"
)
- def _sample_to_balance(self) -> None:
+ def _sample_to_balance(
+ self, data: Tensor, targets: Tensor
+ ) -> Tuple[np.ndarray, np.ndarray]:
"""Because the dataset is not balanced, we take at most the mean number of instances per class."""
np.random.seed(self.seed)
- x = self.data
- y = self.targets
+ x = data
+ y = targets
num_to_sample = int(np.bincount(y.flatten()).mean())
all_sampled_indices = []
for label in np.unique(y.flatten()):
@@ -253,20 +250,22 @@ class EmnistDataset(Dataset):
indices = np.concatenate(all_sampled_indices)
x_sampled = x[indices]
y_sampled = y[indices]
- self.data = x_sampled
- self.targets = y_sampled
+ data = x_sampled
+ targets = y_sampled
+ return data, targets
- def _subsample(self) -> None:
+ def _subsample(self, data: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
"""Subsamples the dataset to the specified fraction."""
- x = self.data
- y = self.targets
+ x = data
+ y = targets
num_samples = int(x.shape[0] * self.subsample_fraction)
x_sampled = x[:num_samples]
y_sampled = y[:num_samples]
self.data = x_sampled
self.targets = y_sampled
+ return x_sampled, y_sampled  # return the truncated arrays, not the untouched inputs
- def load_emnist_dataset(self) -> None:
+ def load_emnist_dataset(self) -> Tuple[Tensor, Tensor]:
"""Fetch the EMNIST dataset."""
dataset = EMNIST(
root=DATA_DIRNAME,
@@ -277,11 +276,13 @@ class EmnistDataset(Dataset):
target_transform=None,
)
- self.data = dataset.data
- self.targets = dataset.targets
+ data = dataset.data
+ targets = dataset.targets
if self.sample_to_balance:
- self._sample_to_balance()
+ data, targets = self._sample_to_balance(data, targets)
if self.subsample_fraction is not None:
- self._subsample()
+ data, targets = self._subsample(data, targets)
+
+ return data, targets
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index d64a991..b0617f5 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -8,6 +8,7 @@ import h5py
from loguru import logger
import numpy as np
import torch
+from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, Normalize, ToTensor
@@ -87,16 +88,14 @@ class EmnistLinesDataset(Dataset):
"""Returns the length of the dataset."""
return len(self.data)
- def __getitem__(
- self, index: Union[int, torch.Tensor]
- ) -> Tuple[torch.Tensor, torch.Tensor]:
+ def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
"""Fetches the data, target pair of the dataset for a given index or indices.
Args:
- index (Union[int, torch.Tensor]): Either a list or int of indices/index.
+ index (Union[int, Tensor]): Either a single int index or a tensor of indices.
Returns:
- Tuple[torch.Tensor, torch.Tensor]: Data target pair.
+ Tuple[Tensor, Tensor]: Data target pair.
"""
if torch.is_tensor(index):