diff options
Diffstat (limited to 'src/text_recognizer/datasets')
| -rw-r--r-- | src/text_recognizer/datasets/emnist_dataset.py | 43 | ||||
| -rw-r--r-- | src/text_recognizer/datasets/emnist_lines_dataset.py | 9 | 
2 files changed, 26 insertions, 26 deletions
| diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py index 96f84e5..49ebad3 100644 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ b/src/text_recognizer/datasets/emnist_dataset.py @@ -8,6 +8,7 @@ from loguru import logger  import numpy as np  from PIL import Image  import torch +from torch import Tensor  from torch.utils.data import DataLoader, Dataset  from torchvision.datasets import EMNIST  from torchvision.transforms import Compose, Normalize, ToTensor @@ -183,12 +184,8 @@ class EmnistDataset(Dataset):          self.input_shape = self._mapper.input_shape          self.num_classes = self._mapper.num_classes -        # Placeholders -        self.data = None -        self.targets = None -          # Load dataset. -        self.load_emnist_dataset() +        self.data, self.targets = self.load_emnist_dataset()      @property      def mapper(self) -> EmnistMapper: @@ -199,9 +196,7 @@ class EmnistDataset(Dataset):          """Returns the length of the dataset."""          return len(self.data) -    def __getitem__( -        self, index: Union[int, torch.Tensor] -    ) -> Tuple[torch.Tensor, torch.Tensor]: +    def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:          """Fetches samples from the dataset.          Args: @@ -239,11 +234,13 @@ class EmnistDataset(Dataset):              f"Mapping: {self.mapper.mapping}\n"          ) -    def _sample_to_balance(self) -> None: +    def _sample_to_balance( +        self, data: Tensor, targets: Tensor +    ) -> Tuple[np.ndarray, np.ndarray]:          """Because the dataset is not balanced, we take at most the mean number of instances per class."""          np.random.seed(self.seed) -        x = self.data -        y = self.targets +        x = data +        y = targets          num_to_sample = int(np.bincount(y.flatten()).mean())          all_sampled_indices = []          for label in np.unique(y.flatten()): @@ -253,20 +250,22 @@ class EmnistDataset(Dataset):          indices = np.concatenate(all_sampled_indices)          x_sampled = x[indices]          y_sampled = y[indices] -        self.data = x_sampled -        self.targets = y_sampled +        data = x_sampled +        targets = y_sampled +        return data, targets -    def _subsample(self) -> None: +    def _subsample(self, data: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:          """Subsamples the dataset to the specified fraction.""" -        x = self.data -        y = self.targets +        x = data +        y = targets          num_samples = int(x.shape[0] * self.subsample_fraction)          x_sampled = x[:num_samples]          y_sampled = y[:num_samples]          self.data = x_sampled          self.targets = y_sampled +        return data, targets -    def load_emnist_dataset(self) -> None: +    def load_emnist_dataset(self) -> Tuple[Tensor, Tensor]:          """Fetch the EMNIST dataset."""          dataset = EMNIST(              root=DATA_DIRNAME, @@ -277,11 +276,13 @@ class EmnistDataset(Dataset):              target_transform=None,          ) -        self.data = dataset.data -        self.targets = dataset.targets +        data = dataset.data +        targets = dataset.targets          if self.sample_to_balance: -            self._sample_to_balance() +            data, targets = self._sample_to_balance(data, targets)          if self.subsample_fraction is not None: -            self._subsample() +            data, targets = self._subsample(data, targets) + +        return data, targets diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py index d64a991..b0617f5 100644 --- a/src/text_recognizer/datasets/emnist_lines_dataset.py +++ b/src/text_recognizer/datasets/emnist_lines_dataset.py @@ -8,6 +8,7 @@ import h5py  from loguru import logger  import numpy as np  import torch +from torch import Tensor  from torch.utils.data import DataLoader, Dataset  from torchvision.transforms import Compose, Normalize, ToTensor @@ -87,16 +88,14 @@ class EmnistLinesDataset(Dataset):          """Returns the length of the dataset."""          return len(self.data) -    def __getitem__( -        self, index: Union[int, torch.Tensor] -    ) -> Tuple[torch.Tensor, torch.Tensor]: +    def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:          """Fetches data, target pair of the dataset for a given and index or indices.          Args: -            index (Union[int, torch.Tensor]): Either a list or int of indices/index. +            index (Union[int, Tensor]): Either a list or int of indices/index.          Returns: -            Tuple[torch.Tensor, torch.Tensor]: Data target pair. +            Tuple[Tensor, Tensor]: Data target pair.          """          if torch.is_tensor(index): |