diff options
Diffstat (limited to 'src/text_recognizer')
-rw-r--r-- | src/text_recognizer/datasets/__init__.py | 2 | ||||
-rw-r--r-- | src/text_recognizer/datasets/data_loader.py | 15 | ||||
-rw-r--r-- | src/text_recognizer/datasets/emnist_dataset.py | 258 | ||||
-rw-r--r-- | src/text_recognizer/models/base.py | 12 | ||||
-rw-r--r-- | src/text_recognizer/models/character_model.py | 8 | ||||
-rw-r--r-- | src/text_recognizer/models/metrics.py (renamed from src/text_recognizer/models/util.py) | 2 |
6 files changed, 169 insertions, 128 deletions
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py index 929cfb9..aec5bf9 100644 --- a/src/text_recognizer/datasets/__init__.py +++ b/src/text_recognizer/datasets/__init__.py @@ -1,2 +1,2 @@ """Dataset modules.""" -from .data_loader import fetch_data_loader +from .emnist_dataset import EmnistDataLoader diff --git a/src/text_recognizer/datasets/data_loader.py b/src/text_recognizer/datasets/data_loader.py deleted file mode 100644 index fd55934..0000000 --- a/src/text_recognizer/datasets/data_loader.py +++ /dev/null @@ -1,15 +0,0 @@ -"""Data loader collection.""" - -from typing import Dict - -from torch.utils.data import DataLoader - -from text_recognizer.datasets.emnist_dataset import fetch_emnist_data_loader - - -def fetch_data_loader(data_loader_args: Dict) -> DataLoader: - """Fetches the specified PyTorch data loader.""" - if data_loader_args.pop("name").lower() == "emnist": - return fetch_emnist_data_loader(data_loader_args) - else: - raise NotImplementedError diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py index f9c8ffa..a17d7a9 100644 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ b/src/text_recognizer/datasets/emnist_dataset.py @@ -24,7 +24,7 @@ class Transpose: return np.array(image).swapaxes(0, 1) -def save_emnist_essentials(emnsit_dataset: EMNIST) -> None: +def save_emnist_essentials(emnsit_dataset: type = EMNIST) -> None: """Extract and saves EMNIST essentials.""" labels = emnsit_dataset.classes labels.sort() @@ -45,111 +45,163 @@ def download_emnist() -> None: save_emnist_essentials(dataset) -def load_emnist_mapping() -> Dict: +def load_emnist_mapping() -> Dict[int, str]: """Load the EMNIST mapping.""" with open(str(ESSENTIALS_FILENAME)) as f: essentials = json.load(f) return dict(essentials["mapping"]) -def _sample_to_balance(dataset: EMNIST, seed: int = 4711) -> None: - """Because the dataset is not balanced, we take at most the mean number of instances per class.""" - np.random.seed(seed) - x = dataset.data - y = dataset.targets - num_to_sample = int(np.bincount(y.flatten()).mean()) - all_sampled_inds = [] - for label in np.unique(y.flatten()): - inds = np.where(y == label)[0] - sampled_inds = np.unique(np.random.choice(inds, num_to_sample)) - all_sampled_inds.append(sampled_inds) - ind = np.concatenate(all_sampled_inds) - x_sampled = x[ind] - y_sampled = y[ind] - dataset.data = x_sampled - dataset.targets = y_sampled - - -def fetch_emnist_dataset( - split: str, - train: bool, - sample_to_balance: bool = False, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, -) -> EMNIST: - """Fetch the EMNIST dataset.""" - if transform is None: - transform = Compose([Transpose(), ToTensor()]) - - dataset = EMNIST( - root=DATA_DIRNAME, - split="byclass", - train=train, - download=False, - transform=transform, - target_transform=target_transform, - ) - - if sample_to_balance and split == "byclass": - _sample_to_balance(dataset) - - return dataset - - -def fetch_emnist_data_loader( - splits: List[str], - sample_to_balance: bool = False, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - batch_size: int = 128, - shuffle: bool = False, - num_workers: int = 0, - cuda: bool = True, -) -> Dict[DataLoader]: - """Fetches the EMNIST dataset and return a PyTorch DataLoader. - - Args: - splits (List[str]): One or both of the dataset splits "train" and "val". - sample_to_balance (bool): If true, resamples the unbalanced if the split "byclass" is selected. - Defaults to False. - transform (Optional[Callable]): A function/transform that takes in an PIL image and returns a - transformed version. E.g, transforms.RandomCrop. Defaults to None. - target_transform (Optional[Callable]): A function/transform that takes in the target and transforms - it. - Defaults to None. - batch_size (int): How many samples per batch to load. Defaults to 128. - shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False. - num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be - loaded in the main process. Defaults to 0. - cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning - them. Defaults to True. - - Returns: - Dict: A dict containing PyTorch DataLoader(s) with emnist characters. - - """ - data_loaders = {} - - for split in ["train", "val"]: - if split in splits: - - if split == "train": - train = True - else: - train = False - - dataset = fetch_emnist_dataset( - split, train, sample_to_balance, transform, target_transform - ) - - data_loader = DataLoader( - dataset=dataset, - batch_size=batch_size, - shuffle=shuffle, - num_workers=num_workers, - pin_memory=cuda, - ) - - data_loaders[split] = data_loader - - return data_loaders +class EmnistDataLoader: + """Class for Emnist DataLoaders.""" + + def __init__( + self, + splits: List[str], + sample_to_balance: bool = False, + subsample_fraction: float = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + batch_size: int = 128, + shuffle: bool = False, + num_workers: int = 0, + cuda: bool = True, + seed: int = 4711, + ) -> None: + """Fetches DataLoaders. + + Args: + splits (List[str]): One or both of the dataset splits "train" and "val". + sample_to_balance (bool): If true, resamples the unbalanced if the split "byclass" is selected. + Defaults to False. + subsample_fraction (float): The fraction of the dataset will be loaded. If None or 0 the entire + dataset will be loaded. + transform (Optional[Callable]): A function/transform that takes in an PIL image and returns a + transformed version. E.g, transforms.RandomCrop. Defaults to None. + target_transform (Optional[Callable]): A function/transform that takes in the target and + transforms it. Defaults to None. + batch_size (int): How many samples per batch to load. Defaults to 128. + shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False. + num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be + loaded in the main process. Defaults to 0. + cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning + them. Defaults to True. + seed (int): Seed for sampling. + + """ + self.splits = splits + self.sample_to_balance = sample_to_balance + if subsample_fraction is not None: + assert ( + 0.0 < subsample_fraction < 1.0 + ), " The subsample fraction must be in (0, 1)." + self.subsample_fraction = subsample_fraction + self.transform = transform + self.target_transform = target_transform + self.batch_size = batch_size + self.shuffle = shuffle + self.num_workers = num_workers + self.cuda = cuda + self._data_loaders = self._fetch_emnist_data_loaders() + + @property + def __name__(self) -> str: + """Returns the name of the dataset.""" + return "EMNIST" + + def __call__(self, split: str) -> Optional[DataLoader]: + """Returns the `split` DataLoader. + + Args: + split (str): The dataset split, i.e. train or val. + + Returns: + type: A PyTorch DataLoader. + + Raises: + ValueError: If the split does not exist. + + """ + try: + return self._data_loaders[split] + except KeyError: + raise ValueError(f"Split {split} does not exist.") + + def _sample_to_balance(self, dataset: type = EMNIST) -> EMNIST: + """Because the dataset is not balanced, we take at most the mean number of instances per class.""" + np.random.seed(self.seed) + x = dataset.data + y = dataset.targets + num_to_sample = int(np.bincount(y.flatten()).mean()) + all_sampled_indices = [] + for label in np.unique(y.flatten()): + inds = np.where(y == label)[0] + sampled_indices = np.unique(np.random.choice(inds, num_to_sample)) + all_sampled_indices.append(sampled_indices) + indices = np.concatenate(all_sampled_indices) + x_sampled = x[indices] + y_sampled = y[indices] + dataset.data = x_sampled + dataset.targets = y_sampled + + return dataset + + def _subsample(self, dataset: type = EMNIST) -> EMNIST: + """Subsamples the dataset to the specified fraction.""" + x = dataset.data + y = dataset.targets + num_samples = int(x.shape[0] * self.subsample_fraction) + x_sampled = x[:num_samples] + y_sampled = y[:num_samples] + dataset.data = x_sampled + dataset.targets = y_sampled + + return dataset + + def _fetch_emnist_dataset(self, train: bool) -> EMNIST: + """Fetch the EMNIST dataset.""" + if self.transform is None: + transform = Compose([Transpose(), ToTensor()]) + + dataset = EMNIST( + root=DATA_DIRNAME, + split="byclass", + train=train, + download=False, + transform=transform, + target_transform=self.target_transform, + ) + + if self.sample_to_balance: + dataset = self._sample_to_balance(dataset) + + if self.subsample_fraction is not None: + dataset = self._subsample(dataset) + + return dataset + + def _fetch_emnist_data_loaders(self) -> Dict[str, DataLoader]: + """Fetches the EMNIST dataset and return a Dict of PyTorch DataLoaders.""" + data_loaders = {} + + for split in ["train", "val"]: + if split in self.splits: + + if split == "train": + train = True + else: + train = False + + dataset = self._fetch_emnist_dataset(train) + + data_loader = DataLoader( + dataset=dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + pin_memory=self.cuda, + ) + + data_loaders[split] = data_loader + + return data_loaders diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index 736af7b..0cc531a 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -10,7 +10,6 @@ import torch from torch import nn from torchsummary import summary -from text_recognizer.dataset.data_loader import fetch_data_loader WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights" @@ -22,6 +21,7 @@ class Model(ABC): self, network_fn: Callable, network_args: Dict, + data_loader: Optional[Callable] = None, data_loader_args: Optional[Dict] = None, metrics: Optional[Dict] = None, criterion: Optional[Callable] = None, @@ -32,12 +32,13 @@ class Model(ABC): lr_scheduler_args: Optional[Dict] = None, device: Optional[str] = None, ) -> None: - """Base class, to be inherited by predictors for specific type of data. + """Base class, to be inherited by model for specific type of data. Args: network_fn (Callable): The PyTorch network. network_args (Dict): Arguments for the network. - data_loader_args (Optional[Dict]): Arguments for the data loader. + data_loader (Optional[Callable]): A function that fetches train and val DataLoader. + data_loader_args (Optional[Dict]): Arguments for the DataLoader. metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None. criterion (Optional[Callable]): The criterion to evaulate the preformance of the network. Defaults to None. @@ -53,8 +54,8 @@ class Model(ABC): # Fetch data loaders. if data_loader_args is not None: - self._data_loaders = fetch_data_loader(**data_loader_args) - dataset_name = self._data_loaders.items()[0].dataset.__name__ + self._data_loaders = data_loader(**data_loader_args) + dataset_name = self._data_loaders.__name__ else: dataset_name = "" self._data_loaders = None @@ -210,7 +211,6 @@ class Model(ABC): logger.debug( f"Found a new best {val_metric}. Saving best checkpoint and weights." ) - self.save_weights() shutil.copyfile(filepath, str(path / "best.pt")) def load_weights(self) -> None: diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py index 1570344..fd69bf2 100644 --- a/src/text_recognizer/models/character_model.py +++ b/src/text_recognizer/models/character_model.py @@ -32,17 +32,21 @@ class CharacterModel(Model): super().__init__( network_fn, - data_loader_args, network_args, + data_loader_args, metrics, criterion, + criterion_args, optimizer, + optimizer_args, + lr_scheduler, + lr_scheduler_args, device, ) self.emnist_mapping = self.mapping() self.eval() - def mapping(self) -> Dict: + def mapping(self) -> Dict[int, str]: """Mapping between integers and classes.""" mapping = load_emnist_mapping() return mapping diff --git a/src/text_recognizer/models/util.py b/src/text_recognizer/models/metrics.py index 905fe7b..e2a30a9 100644 --- a/src/text_recognizer/models/util.py +++ b/src/text_recognizer/models/metrics.py @@ -4,7 +4,7 @@ import torch def accuracy(outputs: torch.Tensor, labels: torch.Tensro) -> float: - """Short summary. + """Computes the accuracy. Args: outputs (torch.Tensor): The output from the network. |