Diffstat (limited to 'src/text_recognizer')
-rw-r--r--  src/text_recognizer/datasets/__init__.py        |   2
-rw-r--r--  src/text_recognizer/datasets/data_loader.py     |  15
-rw-r--r--  src/text_recognizer/datasets/emnist_dataset.py  | 258
-rw-r--r--  src/text_recognizer/models/base.py              |  12
-rw-r--r--  src/text_recognizer/models/character_model.py   |   8
-rw-r--r--  src/text_recognizer/models/metrics.py (renamed from src/text_recognizer/models/util.py) | 2
6 files changed, 169 insertions(+), 128 deletions(-)
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py
index 929cfb9..aec5bf9 100644
--- a/src/text_recognizer/datasets/__init__.py
+++ b/src/text_recognizer/datasets/__init__.py
@@ -1,2 +1,2 @@
"""Dataset modules."""
-from .data_loader import fetch_data_loader
+from .emnist_dataset import EmnistDataLoader
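
With this change, callers import the loader class directly instead of going through a
string-dispatch helper. A minimal usage sketch (the constructor arguments are the ones
defined in emnist_dataset.py below):

    from text_recognizer.datasets import EmnistDataLoader

    # Builds the train/val DataLoaders eagerly in __init__.
    data_loaders = EmnistDataLoader(splits=["train", "val"], batch_size=128)
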
diff --git a/src/text_recognizer/datasets/data_loader.py b/src/text_recognizer/datasets/data_loader.py
deleted file mode 100644
index fd55934..0000000
--- a/src/text_recognizer/datasets/data_loader.py
+++ /dev/null
@@ -1,15 +0,0 @@
-"""Data loader collection."""
-
-from typing import Dict
-
-from torch.utils.data import DataLoader
-
-from text_recognizer.datasets.emnist_dataset import fetch_emnist_data_loader
-
-
-def fetch_data_loader(data_loader_args: Dict) -> DataLoader:
- """Fetches the specified PyTorch data loader."""
- if data_loader_args.pop("name").lower() == "emnist":
- return fetch_emnist_data_loader(data_loader_args)
- else:
- raise NotImplementedError
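
For context, the deleted dispatcher selected a dataset via the "name" key of a config
dict and handed the remaining dict to fetch_emnist_data_loader. A sketch of the old
call site:

    # Old-style dispatch: "name" is popped, the remaining dict is passed on
    # to the EMNIST-specific fetch function.
    loaders = fetch_data_loader({"name": "emnist", "splits": ["train", "val"]})
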
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index f9c8ffa..a17d7a9 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -24,7 +24,7 @@ class Transpose:
return np.array(image).swapaxes(0, 1)
-def save_emnist_essentials(emnsit_dataset: EMNIST) -> None:
+def save_emnist_essentials(emnist_dataset: EMNIST) -> None:
    """Extracts and saves the EMNIST essentials."""
    labels = emnist_dataset.classes
    labels.sort()
@@ -45,111 +45,163 @@ def download_emnist() -> None:
save_emnist_essentials(dataset)
-def load_emnist_mapping() -> Dict:
+def load_emnist_mapping() -> Dict[int, str]:
"""Load the EMNIST mapping."""
with open(str(ESSENTIALS_FILENAME)) as f:
essentials = json.load(f)
return dict(essentials["mapping"])
-def _sample_to_balance(dataset: EMNIST, seed: int = 4711) -> None:
- """Because the dataset is not balanced, we take at most the mean number of instances per class."""
- np.random.seed(seed)
- x = dataset.data
- y = dataset.targets
- num_to_sample = int(np.bincount(y.flatten()).mean())
- all_sampled_inds = []
- for label in np.unique(y.flatten()):
- inds = np.where(y == label)[0]
- sampled_inds = np.unique(np.random.choice(inds, num_to_sample))
- all_sampled_inds.append(sampled_inds)
- ind = np.concatenate(all_sampled_inds)
- x_sampled = x[ind]
- y_sampled = y[ind]
- dataset.data = x_sampled
- dataset.targets = y_sampled
-
-
-def fetch_emnist_dataset(
- split: str,
- train: bool,
- sample_to_balance: bool = False,
- transform: Optional[Callable] = None,
- target_transform: Optional[Callable] = None,
-) -> EMNIST:
- """Fetch the EMNIST dataset."""
- if transform is None:
- transform = Compose([Transpose(), ToTensor()])
-
- dataset = EMNIST(
- root=DATA_DIRNAME,
- split="byclass",
- train=train,
- download=False,
- transform=transform,
- target_transform=target_transform,
- )
-
- if sample_to_balance and split == "byclass":
- _sample_to_balance(dataset)
-
- return dataset
-
-
-def fetch_emnist_data_loader(
- splits: List[str],
- sample_to_balance: bool = False,
- transform: Optional[Callable] = None,
- target_transform: Optional[Callable] = None,
- batch_size: int = 128,
- shuffle: bool = False,
- num_workers: int = 0,
- cuda: bool = True,
-) -> Dict[DataLoader]:
- """Fetches the EMNIST dataset and return a PyTorch DataLoader.
-
- Args:
- splits (List[str]): One or both of the dataset splits "train" and "val".
- sample_to_balance (bool): If true, resamples the unbalanced if the split "byclass" is selected.
- Defaults to False.
- transform (Optional[Callable]): A function/transform that takes in an PIL image and returns a
- transformed version. E.g, transforms.RandomCrop. Defaults to None.
- target_transform (Optional[Callable]): A function/transform that takes in the target and transforms
- it.
- Defaults to None.
- batch_size (int): How many samples per batch to load. Defaults to 128.
- shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False.
- num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be
- loaded in the main process. Defaults to 0.
- cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning
- them. Defaults to True.
-
- Returns:
- Dict: A dict containing PyTorch DataLoader(s) with emnist characters.
-
- """
- data_loaders = {}
-
- for split in ["train", "val"]:
- if split in splits:
-
- if split == "train":
- train = True
- else:
- train = False
-
- dataset = fetch_emnist_dataset(
- split, train, sample_to_balance, transform, target_transform
- )
-
- data_loader = DataLoader(
- dataset=dataset,
- batch_size=batch_size,
- shuffle=shuffle,
- num_workers=num_workers,
- pin_memory=cuda,
- )
-
- data_loaders[split] = data_loader
-
- return data_loaders
+class EmnistDataLoader:
+ """Class for Emnist DataLoaders."""
+
+ def __init__(
+ self,
+ splits: List[str],
+ sample_to_balance: bool = False,
+        subsample_fraction: Optional[float] = None,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ batch_size: int = 128,
+ shuffle: bool = False,
+ num_workers: int = 0,
+ cuda: bool = True,
+ seed: int = 4711,
+ ) -> None:
+ """Fetches DataLoaders.
+
+ Args:
+ splits (List[str]): One or both of the dataset splits "train" and "val".
+            sample_to_balance (bool): If True, resamples the unbalanced "byclass" split so
+                that each class has at most the mean number of instances. Defaults to False.
+            subsample_fraction (Optional[float]): The fraction of the dataset that will be
+                loaded. If None, the entire dataset is loaded. Defaults to None.
+            transform (Optional[Callable]): A function/transform that takes in a PIL image
+                and returns a transformed version, e.g. transforms.RandomCrop. Defaults to None.
+            target_transform (Optional[Callable]): A function/transform that takes in the
+                target and transforms it. Defaults to None.
+ batch_size (int): How many samples per batch to load. Defaults to 128.
+ shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False.
+ num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be
+ loaded in the main process. Defaults to 0.
+ cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning
+ them. Defaults to True.
+ seed (int): Seed for sampling.
+
+ """
+ self.splits = splits
+ self.sample_to_balance = sample_to_balance
+        if subsample_fraction is not None:
+            assert (
+                0.0 < subsample_fraction < 1.0
+            ), "The subsample fraction must be in (0, 1)."
+        self.subsample_fraction = subsample_fraction
+ self.transform = transform
+ self.target_transform = target_transform
+ self.batch_size = batch_size
+ self.shuffle = shuffle
+ self.num_workers = num_workers
+        self.cuda = cuda
+        self.seed = seed
+ self._data_loaders = self._fetch_emnist_data_loaders()
+
+ @property
+ def __name__(self) -> str:
+ """Returns the name of the dataset."""
+ return "EMNIST"
+
+ def __call__(self, split: str) -> Optional[DataLoader]:
+ """Returns the `split` DataLoader.
+
+ Args:
+ split (str): The dataset split, i.e. train or val.
+
+ Returns:
+            DataLoader: A PyTorch DataLoader for the given split.
+
+ Raises:
+ ValueError: If the split does not exist.
+
+ """
+ try:
+ return self._data_loaders[split]
+ except KeyError:
+ raise ValueError(f"Split {split} does not exist.")
+
+    def _sample_to_balance(self, dataset: EMNIST) -> EMNIST:
+ """Because the dataset is not balanced, we take at most the mean number of instances per class."""
+ np.random.seed(self.seed)
+ x = dataset.data
+ y = dataset.targets
+ num_to_sample = int(np.bincount(y.flatten()).mean())
+ all_sampled_indices = []
+ for label in np.unique(y.flatten()):
+ inds = np.where(y == label)[0]
+ sampled_indices = np.unique(np.random.choice(inds, num_to_sample))
+ all_sampled_indices.append(sampled_indices)
+ indices = np.concatenate(all_sampled_indices)
+ x_sampled = x[indices]
+ y_sampled = y[indices]
+ dataset.data = x_sampled
+ dataset.targets = y_sampled
+
+ return dataset
+
+    def _subsample(self, dataset: EMNIST) -> EMNIST:
+ """Subsamples the dataset to the specified fraction."""
+ x = dataset.data
+ y = dataset.targets
+ num_samples = int(x.shape[0] * self.subsample_fraction)
+ x_sampled = x[:num_samples]
+ y_sampled = y[:num_samples]
+ dataset.data = x_sampled
+ dataset.targets = y_sampled
+
+ return dataset
+
+ def _fetch_emnist_dataset(self, train: bool) -> EMNIST:
+ """Fetch the EMNIST dataset."""
+        transform = self.transform
+        if transform is None:
+            transform = Compose([Transpose(), ToTensor()])
+
+ dataset = EMNIST(
+ root=DATA_DIRNAME,
+ split="byclass",
+ train=train,
+ download=False,
+ transform=transform,
+ target_transform=self.target_transform,
+ )
+
+ if self.sample_to_balance:
+ dataset = self._sample_to_balance(dataset)
+
+ if self.subsample_fraction is not None:
+ dataset = self._subsample(dataset)
+
+ return dataset
+
+ def _fetch_emnist_data_loaders(self) -> Dict[str, DataLoader]:
+ """Fetches the EMNIST dataset and return a Dict of PyTorch DataLoaders."""
+ data_loaders = {}
+
+ for split in ["train", "val"]:
+ if split in self.splits:
+
+ if split == "train":
+ train = True
+ else:
+ train = False
+
+ dataset = self._fetch_emnist_dataset(train)
+
+ data_loader = DataLoader(
+ dataset=dataset,
+ batch_size=self.batch_size,
+ shuffle=self.shuffle,
+ num_workers=self.num_workers,
+ pin_memory=self.cuda,
+ )
+
+ data_loaders[split] = data_loader
+
+ return data_loaders
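
Taken together, the new class is constructed once and then indexed by split via __call__.
A usage sketch based only on the methods defined above:

    loader = EmnistDataLoader(
        splits=["train", "val"],
        sample_to_balance=True,     # resample down to the mean class count
        subsample_fraction=0.5,     # keep the first half of the samples
        batch_size=64,
        shuffle=True,
    )
    train_loader = loader("train")  # __call__ looks up the split's DataLoader
    print(loader.__name__)          # "EMNIST"; the Model base class reads this below
    # loader("test") would raise ValueError, since only "train"/"val" were built.
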
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index 736af7b..0cc531a 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -10,7 +10,6 @@ import torch
from torch import nn
from torchsummary import summary
-from text_recognizer.dataset.data_loader import fetch_data_loader
WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights"
@@ -22,6 +21,7 @@ class Model(ABC):
self,
network_fn: Callable,
network_args: Dict,
+ data_loader: Optional[Callable] = None,
data_loader_args: Optional[Dict] = None,
metrics: Optional[Dict] = None,
criterion: Optional[Callable] = None,
@@ -32,12 +32,13 @@ class Model(ABC):
lr_scheduler_args: Optional[Dict] = None,
device: Optional[str] = None,
) -> None:
- """Base class, to be inherited by predictors for specific type of data.
+ """Base class, to be inherited by model for specific type of data.
Args:
network_fn (Callable): The PyTorch network.
network_args (Dict): Arguments for the network.
- data_loader_args (Optional[Dict]): Arguments for the data loader.
+            data_loader (Optional[Callable]): A callable that returns the train and val DataLoaders.
+ data_loader_args (Optional[Dict]): Arguments for the DataLoader.
metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None.
            criterion (Optional[Callable]): The criterion to evaluate the performance of the network.
Defaults to None.
@@ -53,8 +54,8 @@ class Model(ABC):
# Fetch data loaders.
if data_loader_args is not None:
- self._data_loaders = fetch_data_loader(**data_loader_args)
- dataset_name = self._data_loaders.items()[0].dataset.__name__
+ self._data_loaders = data_loader(**data_loader_args)
+ dataset_name = self._data_loaders.__name__
else:
dataset_name = ""
self._data_loaders = None
@@ -210,7 +211,6 @@ class Model(ABC):
logger.debug(
f"Found a new best {val_metric}. Saving best checkpoint and weights."
)
- self.save_weights()
shutil.copyfile(filepath, str(path / "best.pt"))
def load_weights(self) -> None:
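
The net effect in Model is dependency injection: instead of importing a dataset-specific
fetch function, the caller passes in any callable with this interface. A hedged sketch of
the new wiring (SomeModel, MLP, and the argument values are placeholders, not part of
this diff):

    from text_recognizer.datasets import EmnistDataLoader

    model = SomeModel(
        network_fn=MLP,                    # placeholder network constructor
        network_args={"num_classes": 62},  # illustrative arguments
        data_loader=EmnistDataLoader,      # invoked as data_loader(**data_loader_args)
        data_loader_args={"splits": ["train", "val"], "batch_size": 128},
    )
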
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
index 1570344..fd69bf2 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/src/text_recognizer/models/character_model.py
@@ -32,17 +32,21 @@ class CharacterModel(Model):
super().__init__(
network_fn,
- data_loader_args,
network_args,
+ data_loader_args,
metrics,
criterion,
+ criterion_args,
optimizer,
+ optimizer_args,
+ lr_scheduler,
+ lr_scheduler_args,
device,
)
self.emnist_mapping = self.mapping()
self.eval()
- def mapping(self) -> Dict:
+ def mapping(self) -> Dict[int, str]:
"""Mapping between integers and classes."""
mapping = load_emnist_mapping()
return mapping
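
With the mapping in place, CharacterModel can decode an integer class prediction into a
character. An illustrative sketch (the forward pass and tensor names are assumptions;
only emnist_mapping comes from this diff):

    # Hypothetical decoding step for a batch of size 1:
    logits = self.network(image_tensor)       # assumed forward pass
    index = int(torch.argmax(logits, dim=1))  # predicted class index
    character = self.emnist_mapping[index]    # e.g. 7 -> "7"
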
diff --git a/src/text_recognizer/models/util.py b/src/text_recognizer/models/metrics.py
index 905fe7b..e2a30a9 100644
--- a/src/text_recognizer/models/util.py
+++ b/src/text_recognizer/models/metrics.py
@@ -4,7 +4,7 @@ import torch
def accuracy(outputs: torch.Tensor, labels: torch.Tensor) -> float:
- """Short summary.
+ """Computes the accuracy.
Args:
outputs (torch.Tensor): The output from the network.
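
A minimal body consistent with this signature might look as follows (a sketch, not
necessarily the repository's exact implementation):

    import torch

    def accuracy(outputs: torch.Tensor, labels: torch.Tensor) -> float:
        """Computes the accuracy of a batch of predictions."""
        # The predicted class is the index of the largest logit per sample.
        predictions = torch.argmax(outputs, dim=1)
        return (predictions == labels).float().mean().item()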