diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-09 23:24:02 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-09 23:24:02 +0200 |
commit | 53677be4ec14854ea4881b0d78730e0414c8dedd (patch) | |
tree | 56eaace5e9906c7d408b6a251ca100b5c8b4e991 /src/text_recognizer | |
parent | 125d5da5fb845d03bda91426e172bca7f537584a (diff) |
Working bash scripts etc.
Diffstat (limited to 'src/text_recognizer')
-rw-r--r-- | src/text_recognizer/datasets/__init__.py | 13 | ||||
-rw-r--r-- | src/text_recognizer/datasets/emnist_dataset.py | 275 | ||||
-rw-r--r-- | src/text_recognizer/datasets/emnist_lines_dataset.py | 129 | ||||
-rw-r--r-- | src/text_recognizer/datasets/util.py | 60 | ||||
-rw-r--r-- | src/text_recognizer/models/base.py | 139 | ||||
-rw-r--r-- | src/text_recognizer/models/character_model.py | 15 | ||||
-rw-r--r-- | src/text_recognizer/networks/ctc.py | 10 | ||||
-rw-r--r-- | src/text_recognizer/networks/lenet.py | 19 | ||||
-rw-r--r-- | src/text_recognizer/networks/line_lstm_ctc.py | 4 | ||||
-rw-r--r-- | src/text_recognizer/networks/misc.py | 28 | ||||
-rw-r--r-- | src/text_recognizer/networks/mlp.py | 9 | ||||
-rw-r--r-- | src/text_recognizer/networks/residual_network.py | 1 | ||||
-rw-r--r-- | src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt | bin | 0 -> 14485310 bytes | |||
-rw-r--r-- | src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt | bin | 0 -> 1704174 bytes | |||
-rw-r--r-- | src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt | bin | 14485305 -> 14485342 bytes |
15 files changed, 346 insertions, 356 deletions
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py index 1b4cc59..05f74f6 100644 --- a/src/text_recognizer/datasets/__init__.py +++ b/src/text_recognizer/datasets/__init__.py @@ -1,29 +1,24 @@ """Dataset modules.""" from .emnist_dataset import ( - _augment_emnist_mapping, - _load_emnist_essentials, DATA_DIRNAME, - EmnistDataLoaders, EmnistDataset, + EmnistMapper, ESSENTIALS_FILENAME, ) from .emnist_lines_dataset import ( construct_image_from_string, - EmnistLinesDataLoaders, EmnistLinesDataset, get_samples_by_character, ) -from .util import Transpose +from .util import fetch_data_loaders, Transpose __all__ = [ - "_augment_emnist_mapping", - "_load_emnist_essentials", "construct_image_from_string", "DATA_DIRNAME", "EmnistDataset", - "EmnistDataLoaders", - "EmnistLinesDataLoaders", + "EmnistMapper", "EmnistLinesDataset", + "fetch_data_loaders", "get_samples_by_character", "Transpose", ] diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py index f3d65ee..96f84e5 100644 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ b/src/text_recognizer/datasets/emnist_dataset.py @@ -39,45 +39,101 @@ def download_emnist() -> None: save_emnist_essentials(dataset) -def _load_emnist_essentials() -> Dict: - """Load the EMNIST mapping.""" - with open(str(ESSENTIALS_FILENAME)) as f: - essentials = json.load(f) - return essentials - - -def _augment_emnist_mapping(mapping: Dict) -> Dict: - """Augment the mapping with extra symbols.""" - # Extra symbols in IAM dataset - extra_symbols = [ - " ", - "!", - '"', - "#", - "&", - "'", - "(", - ")", - "*", - "+", - ",", - "-", - ".", - "/", - ":", - ";", - "?", - ] - - # padding symbol - extra_symbols.append("_") - - max_key = max(mapping.keys()) - extra_mapping = {} - for i, symbol in enumerate(extra_symbols): - extra_mapping[max_key + 1 + i] = symbol - - return {**mapping, **extra_mapping} +class EmnistMapper: + """Mapper between network output to Emnist character.""" + + def __init__(self) -> None: + """Loads the emnist essentials file with the mapping and input shape.""" + self.essentials = self._load_emnist_essentials() + # Load dataset infromation. + self._mapping = self._augment_emnist_mapping(dict(self.essentials["mapping"])) + self._inverse_mapping = {v: k for k, v in self.mapping.items()} + self._num_classes = len(self.mapping) + self._input_shape = self.essentials["input_shape"] + + def __call__(self, token: Union[str, int, np.uint8]) -> Union[str, int]: + """Maps the token to emnist character or character index. + + If the token is an integer (index), the method will return the Emnist character corresponding to that index. + If the token is a str (Emnist character), the method will return the corresponding index for that character. + + Args: + token (Union[str, int, np.uint8]): Eihter a string or index (integer). + + Returns: + Union[str, int]: The mapping result. + + Raises: + KeyError: If the index or string does not exist in the mapping. + + """ + if (isinstance(token, np.uint8) or isinstance(token, int)) and int( + token + ) in self.mapping: + return self.mapping[int(token)] + elif isinstance(token, str) and token in self._inverse_mapping: + return self._inverse_mapping[token] + else: + raise KeyError(f"Token {token} does not exist in the mappings.") + + @property + def mapping(self) -> Dict: + """Returns the mapping between index and character.""" + return self._mapping + + @property + def inverse_mapping(self) -> Dict: + """Returns the mapping between character and index.""" + return self._inverse_mapping + + @property + def num_classes(self) -> int: + """Returns the number of classes in the dataset.""" + return self._num_classes + + @property + def input_shape(self) -> List[int]: + """Returns the input shape of the Emnist characters.""" + return self._input_shape + + def _load_emnist_essentials(self) -> Dict: + """Load the EMNIST mapping.""" + with open(str(ESSENTIALS_FILENAME)) as f: + essentials = json.load(f) + return essentials + + def _augment_emnist_mapping(self, mapping: Dict) -> Dict: + """Augment the mapping with extra symbols.""" + # Extra symbols in IAM dataset + extra_symbols = [ + " ", + "!", + '"', + "#", + "&", + "'", + "(", + ")", + "*", + "+", + ",", + "-", + ".", + "/", + ":", + ";", + "?", + ] + + # padding symbol + extra_symbols.append("_") + + max_key = max(mapping.keys()) + extra_mapping = {} + for i, symbol in enumerate(extra_symbols): + extra_mapping[max_key + 1 + i] = symbol + + return {**mapping, **extra_mapping} class EmnistDataset(Dataset): @@ -110,10 +166,12 @@ class EmnistDataset(Dataset): self.train = train self.sample_to_balance = sample_to_balance + if subsample_fraction is not None: if not 0.0 < subsample_fraction < 1.0: raise ValueError("The subsample fraction must be in (0, 1).") self.subsample_fraction = subsample_fraction + self.transform = transform if self.transform is None: self.transform = Compose([Transpose(), ToTensor()]) @@ -121,17 +179,22 @@ class EmnistDataset(Dataset): self.target_transform = target_transform self.seed = seed - # Load dataset infromation. - essentials = _load_emnist_essentials() - self.mapping = _augment_emnist_mapping(dict(essentials["mapping"])) - self.inverse_mapping = {v: k for k, v in self.mapping.items()} - self.num_classes = len(self.mapping) - self.input_shape = essentials["input_shape"] + self._mapper = EmnistMapper() + self.input_shape = self._mapper.input_shape + self.num_classes = self._mapper.num_classes # Placeholders self.data = None self.targets = None + # Load dataset. + self.load_emnist_dataset() + + @property + def mapper(self) -> EmnistMapper: + """Returns the EmnistMapper.""" + return self._mapper + def __len__(self) -> int: """Returns the length of the dataset.""" return len(self.data) @@ -162,13 +225,18 @@ class EmnistDataset(Dataset): return data, targets + @property + def __name__(self) -> str: + """Returns the name of the dataset.""" + return "EmnistDataset" + def __repr__(self) -> str: """Returns information about the dataset.""" return ( "EMNIST Dataset\n" f"Num classes: {self.num_classes}\n" - f"Mapping: {self.mapping}\n" f"Input shape: {self.input_shape}\n" + f"Mapping: {self.mapper.mapping}\n" ) def _sample_to_balance(self) -> None: @@ -217,118 +285,3 @@ class EmnistDataset(Dataset): if self.subsample_fraction is not None: self._subsample() - - -class EmnistDataLoaders: - """Class for Emnist DataLoaders.""" - - def __init__( - self, - splits: List[str], - sample_to_balance: bool = False, - subsample_fraction: float = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - batch_size: int = 128, - shuffle: bool = False, - num_workers: int = 0, - cuda: bool = True, - seed: int = 4711, - ) -> None: - """Fetches DataLoaders for given split(s). - - Args: - splits (List[str]): One or both of the dataset splits "train" and "val". - sample_to_balance (bool): If true, resamples the unbalanced if the split "byclass" is selected. - Defaults to False. - subsample_fraction (float): The fraction of the dataset will be loaded. If None or 0 the entire - dataset will be loaded. - transform (Optional[Callable]): A function/transform that takes in an PIL image and returns a - transformed version. E.g, transforms.RandomCrop. Defaults to None. - target_transform (Optional[Callable]): A function/transform that takes in the target and - transforms it. Defaults to None. - batch_size (int): How many samples per batch to load. Defaults to 128. - shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False. - num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be - loaded in the main process. Defaults to 0. - cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning - them. Defaults to True. - seed (int): Seed for sampling. - - Raises: - ValueError: If subsample_fraction is not None and outside the range (0, 1). - - """ - self.splits = splits - - if subsample_fraction is not None: - if not 0.0 < subsample_fraction < 1.0: - raise ValueError("The subsample fraction must be in (0, 1).") - - self.dataset_args = { - "sample_to_balance": sample_to_balance, - "subsample_fraction": subsample_fraction, - "transform": transform, - "target_transform": target_transform, - "seed": seed, - } - self.batch_size = batch_size - self.shuffle = shuffle - self.num_workers = num_workers - self.cuda = cuda - self._data_loaders = self._load_data_loaders() - - def __repr__(self) -> str: - """Returns information about the dataset.""" - return self._data_loaders[self.splits[0]].dataset.__repr__() - - @property - def __name__(self) -> str: - """Returns the name of the dataset.""" - return "Emnist" - - def __call__(self, split: str) -> DataLoader: - """Returns the `split` DataLoader. - - Args: - split (str): The dataset split, i.e. train or val. - - Returns: - DataLoader: A PyTorch DataLoader. - - Raises: - ValueError: If the split does not exist. - - """ - try: - return self._data_loaders[split] - except KeyError: - raise ValueError(f"Split {split} does not exist.") - - def _load_data_loaders(self) -> Dict[str, DataLoader]: - """Fetches the EMNIST dataset and return a Dict of PyTorch DataLoaders.""" - data_loaders = {} - - for split in ["train", "val"]: - if split in self.splits: - - if split == "train": - self.dataset_args["train"] = True - else: - self.dataset_args["train"] = False - - emnist_dataset = EmnistDataset(**self.dataset_args) - - emnist_dataset.load_emnist_dataset() - - data_loader = DataLoader( - dataset=emnist_dataset, - batch_size=self.batch_size, - shuffle=self.shuffle, - num_workers=self.num_workers, - pin_memory=self.cuda, - ) - - data_loaders[split] = data_loader - - return data_loaders diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py index 1c6e959..d64a991 100644 --- a/src/text_recognizer/datasets/emnist_lines_dataset.py +++ b/src/text_recognizer/datasets/emnist_lines_dataset.py @@ -12,10 +12,9 @@ from torch.utils.data import DataLoader, Dataset from torchvision.transforms import Compose, Normalize, ToTensor from text_recognizer.datasets import ( - _augment_emnist_mapping, - _load_emnist_essentials, DATA_DIRNAME, EmnistDataset, + EmnistMapper, ESSENTIALS_FILENAME, ) from text_recognizer.datasets.sentence_generator import SentenceGenerator @@ -30,7 +29,6 @@ class EmnistLinesDataset(Dataset): def __init__( self, train: bool = False, - emnist: Optional[EmnistDataset] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, max_length: int = 34, @@ -39,10 +37,9 @@ class EmnistLinesDataset(Dataset): num_samples: int = 10000, seed: int = 4711, ) -> None: - """Short summary. + """Set attributes and loads the dataset. Args: - emnist (EmnistDataset): A EmnistDataset object. train (bool): Flag for the filename. Defaults to False. Defaults to None. transform (Optional[Callable]): The transform of the data. Defaults to None. target_transform (Optional[Callable]): The transform of the target. Defaults to None. @@ -54,7 +51,6 @@ class EmnistLinesDataset(Dataset): """ self.train = train - self.emnist = emnist self.transform = transform if self.transform is None: @@ -64,11 +60,10 @@ class EmnistLinesDataset(Dataset): if self.target_transform is None: self.target_transform = torch.tensor - # Load emnist dataset infromation. - essentials = _load_emnist_essentials() - self.mapping = _augment_emnist_mapping(dict(essentials["mapping"])) - self.num_classes = len(self.mapping) - self.input_shape = essentials["input_shape"] + # Extract dataset information. + self._mapper = EmnistMapper() + self.input_shape = self._mapper.input_shape + self.num_classes = self._mapper.num_classes self.max_length = max_length self.min_overlap = min_overlap @@ -81,10 +76,13 @@ class EmnistLinesDataset(Dataset): self.output_shape = (self.max_length, self.num_classes) self.seed = seed - # Placeholders for the generated dataset. + # Placeholders for the dataset. self.data = None self.target = None + # Load dataset. + self._load_or_generate_data() + def __len__(self) -> int: """Returns the length of the dataset.""" return len(self.data) @@ -104,7 +102,6 @@ class EmnistLinesDataset(Dataset): if torch.is_tensor(index): index = index.tolist() - # data = np.array([self.data[index]]) data = self.data[index] targets = self.targets[index] @@ -116,6 +113,11 @@ class EmnistLinesDataset(Dataset): return data, targets + @property + def __name__(self) -> str: + """Returns the name of the dataset.""" + return "EmnistLinesDataset" + def __repr__(self) -> str: """Returns information about the dataset.""" return ( @@ -130,6 +132,11 @@ class EmnistLinesDataset(Dataset): ) @property + def mapper(self) -> EmnistMapper: + """Returns the EmnistMapper.""" + return self._mapper + + @property def data_filename(self) -> Path: """Path to the h5 file.""" filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt" @@ -161,9 +168,10 @@ class EmnistLinesDataset(Dataset): sentence_generator = SentenceGenerator(self.max_length) # Load emnist dataset. - self.emnist.load_emnist_dataset() + emnist = EmnistDataset(train=self.train, sample_to_balance=True) + samples_by_character = get_samples_by_character( - self.emnist.data.numpy(), self.emnist.targets.numpy(), self.emnist.mapping, + emnist.data.numpy(), emnist.targets.numpy(), self.mapper.mapping, ) DATA_DIRNAME.mkdir(parents=True, exist_ok=True) @@ -332,94 +340,3 @@ def create_datasets( num_samples=num, ) emnist_lines._load_or_generate_data() - - -class EmnistLinesDataLoaders: - """Wrapper for a PyTorch Data loaders for the EMNIST lines dataset.""" - - def __init__( - self, - splits: List[str], - max_length: int = 34, - min_overlap: float = 0, - max_overlap: float = 0.33, - num_samples: int = 10000, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - batch_size: int = 128, - shuffle: bool = False, - num_workers: int = 0, - cuda: bool = True, - seed: int = 4711, - ) -> None: - """Sets the data loader arguments.""" - self.splits = splits - self.dataset_args = { - "max_length": max_length, - "min_overlap": min_overlap, - "max_overlap": max_overlap, - "num_samples": num_samples, - "transform": transform, - "target_transform": target_transform, - "seed": seed, - } - self.batch_size = batch_size - self.shuffle = shuffle - self.num_workers = num_workers - self.cuda = cuda - self._data_loaders = self._load_data_loaders() - - def __repr__(self) -> str: - """Returns information about the dataset.""" - return self._data_loaders[self.splits[0]].dataset.__repr__() - - @property - def __name__(self) -> str: - """Returns the name of the dataset.""" - return "EmnistLines" - - def __call__(self, split: str) -> DataLoader: - """Returns the `split` DataLoader. - - Args: - split (str): The dataset split, i.e. train or val. - - Returns: - DataLoader: A PyTorch DataLoader. - - Raises: - ValueError: If the split does not exist. - - """ - try: - return self._data_loaders[split] - except KeyError: - raise ValueError(f"Split {split} does not exist.") - - def _load_data_loaders(self) -> Dict[str, DataLoader]: - """Fetches the EMNIST Lines dataset and return a Dict of PyTorch DataLoaders.""" - data_loaders = {} - - for split in ["train", "val"]: - if split in self.splits: - - if split == "train": - self.dataset_args["train"] = True - else: - self.dataset_args["train"] = False - - emnist_lines_dataset = EmnistLinesDataset(**self.dataset_args) - - emnist_lines_dataset._load_or_generate_data() - - data_loader = DataLoader( - dataset=emnist_lines_dataset, - batch_size=self.batch_size, - shuffle=self.shuffle, - num_workers=self.num_workers, - pin_memory=self.cuda, - ) - - data_loaders[split] = data_loader - - return data_loaders diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py index 6668eef..321bc67 100644 --- a/src/text_recognizer/datasets/util.py +++ b/src/text_recognizer/datasets/util.py @@ -1,6 +1,10 @@ """Util functions for datasets.""" +import importlib +from typing import Callable, Dict, List, Type + import numpy as np from PIL import Image +from torch.utils.data import DataLoader, Dataset class Transpose: @@ -9,3 +13,59 @@ class Transpose: def __call__(self, image: Image) -> np.ndarray: """Swaps axis.""" return np.array(image).swapaxes(0, 1) + + +def fetch_data_loaders( + splits: List[str], + dataset: str, + dataset_args: Dict, + batch_size: int = 128, + shuffle: bool = False, + num_workers: int = 0, + cuda: bool = True, +) -> Dict[str, DataLoader]: + """Fetches DataLoaders for given split(s) as a dictionary. + + Loads the dataset class given, and loads it with the dataset arguments, for the number of splits specified. Then + calls the DataLoader. Added to a dictionary with the split as key and DataLoader as value. + + Args: + splits (List[str]): One or both of the dataset splits "train" and "val". + dataset (str): The name of the dataset. + dataset_args (Dict): The dataset arguments. + batch_size (int): How many samples per batch to load. Defaults to 128. + shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False. + num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be + loaded in the main process. Defaults to 0. + cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning + them. Defaults to True. + + Returns: + Dict[str, DataLoader]: Dictionary with split as key and PyTorch DataLoader as value. + + """ + + def check_dataset_args(args: Dict, split: str) -> Dict: + args["train"] = True if split == "train" else False + return args + + # Import dataset module. + datasets_module = importlib.import_module("text_recognizer.datasets") + dataset_ = getattr(datasets_module, dataset) + + data_loaders = {} + + for split in ["train", "val"]: + if split in splits: + + data_loader = DataLoader( + dataset=dataset_(**check_dataset_args(dataset_args, split)), + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + pin_memory=cuda, + ) + + data_loaders[split] = data_loader + + return data_loaders diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index 84a86ca..6d40b49 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -12,6 +12,7 @@ import torch from torch import nn from torchsummary import summary +from text_recognizer.datasets import EmnistMapper, fetch_data_loaders WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights" @@ -23,7 +24,6 @@ class Model(ABC): self, network_fn: Type[nn.Module], network_args: Optional[Dict] = None, - data_loader: Optional[Callable] = None, data_loader_args: Optional[Dict] = None, metrics: Optional[Dict] = None, criterion: Optional[Callable] = None, @@ -39,7 +39,6 @@ class Model(ABC): Args: network_fn (Type[nn.Module]): The PyTorch network. network_args (Optional[Dict]): Arguments for the network. Defaults to None. - data_loader (Optional[Callable]): A function that fetches train and val DataLoader. data_loader_args (Optional[Dict]): Arguments for the DataLoader. metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None. criterion (Optional[Callable]): The criterion to evaulate the preformance of the network. @@ -54,15 +53,11 @@ class Model(ABC): """ - # Fetch data loaders. - if data_loader_args is not None: - self._data_loaders = data_loader(**data_loader_args) - dataset_name = self._data_loaders.__name__ - self._mapping = self._data_loaders.mapping - else: - self._mapping = None - dataset_name = "*" - self._data_loaders = None + # Fetch data loaders and dataset info. + dataset_name, self._data_loaders, self._mapper = self._load_data_loader( + data_loader_args + ) + self._input_shape = self._mapper.input_shape self._name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}" @@ -76,40 +71,15 @@ class Model(ABC): self._device = device # Load network. - self._network = None - self._network_args = network_args - # If no network arguemnts are given, load pretrained weights if they exist. - if self._network_args is None: - self.load_weights(network_fn) - else: - self._network = network_fn(**self._network_args) + self._network, self._network_args = self._load_network(network_fn, network_args) # To device. self._network.to(self._device) - # Set criterion. - self._criterion = None - if criterion is not None: - self._criterion = criterion(**criterion_args) - - # Set optimizer. - self._optimizer = None - if optimizer is not None: - self._optimizer = optimizer(self._network.parameters(), **optimizer_args) - - # Set learning rate scheduler. - self._lr_scheduler = None - if lr_scheduler is not None: - # OneCycleLR needs the number of steps in an epoch as an input argument. - if "OneCycleLR" in str(lr_scheduler): - lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders("train")) - self._lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args) - - # Extract the input shape for the torchsummary. - if isinstance(self._network_args["input_size"], int): - self._input_shape = (1,) + tuple([self._network_args["input_size"]]) - else: - self._input_shape = (1,) + tuple(self._network_args["input_size"]) + # Set training objects. + self._criterion = self._load_criterion(criterion, criterion_args) + self._optimizer = self._load_optimizer(optimizer, optimizer_args) + self._lr_scheduler = self._load_lr_scheduler(lr_scheduler, lr_scheduler_args) # Experiment directory. self.model_dir = None @@ -117,6 +87,64 @@ class Model(ABC): # Flag for stopping training. self.stop_training = False + def _load_data_loader( + self, data_loader_args: Optional[Dict] + ) -> Tuple[str, Dict, EmnistMapper]: + """Loads data loader, dataset name, and dataset mapper.""" + if data_loader_args is not None: + data_loaders = fetch_data_loaders(**data_loader_args) + dataset = list(data_loaders.values())[0].dataset + dataset_name = dataset.__name__ + mapper = dataset.mapper + else: + self._mapper = EmnistMapper() + dataset_name = "*" + data_loaders = None + return dataset_name, data_loaders, mapper + + def _load_network( + self, network_fn: Type[nn.Module], network_args: Optional[Dict] + ) -> Tuple[Type[nn.Module], Dict]: + """Loads the network.""" + # If no network arguemnts are given, load pretrained weights if they exist. + if network_args is None: + network, network_args = self.load_weights(network_fn) + else: + network = network_fn(**network_args) + return network, network_args + + def _load_criterion( + self, criterion: Optional[Callable], criterion_args: Optional[Dict] + ) -> Optional[Callable]: + """Loads the criterion.""" + if criterion is not None: + _criterion = criterion(**criterion_args) + else: + _criterion = None + return _criterion + + def _load_optimizer( + self, optimizer: Optional[Callable], optimizer_args: Optional[Dict] + ) -> Optional[Callable]: + """Loads the optimizer.""" + if optimizer is not None: + _optimizer = optimizer(self._network.parameters(), **optimizer_args) + else: + _optimizer = None + return _optimizer + + def _load_lr_scheduler( + self, lr_scheduler: Optional[Callable], lr_scheduler_args: Optional[Dict] + ) -> Optional[Callable]: + """Loads learning rate scheduler.""" + if self._optimizer and lr_scheduler is not None: + if "OneCycleLR" in str(lr_scheduler): + lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders["train"]) + _lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args) + else: + _lr_scheduler = None + return _lr_scheduler + @property def __name__(self) -> str: """Returns the name of the model.""" @@ -128,9 +156,14 @@ class Model(ABC): return self._input_shape @property + def mapper(self) -> Dict: + """Returns the mapper that maps between ints and chars.""" + return self._mapper + + @property def mapping(self) -> Dict: - """Returns the class mapping.""" - return self._mapping + """Returns the mapping between network output and Emnist character.""" + return self._mapper.mapping def eval(self) -> None: """Sets the network to evaluation mode.""" @@ -184,7 +217,11 @@ class Model(ABC): def summary(self) -> None: """Prints a summary of the network architecture.""" device = re.sub("[^A-Za-z]+", "", self.device) - summary(self._network, self._input_shape, device=device) + if self._input_shape is not None: + input_shape = (1,) + tuple(self._input_shape) + summary(self._network, input_shape, device=device) + else: + logger.warning("Could not print summary as input shape is not set.") def _get_state_dict(self) -> Dict: """Get the state dict of the model.""" @@ -218,8 +255,9 @@ class Model(ABC): if self._optimizer is not None: self._optimizer.load_state_dict(checkpoint["optimizer_state"]) - if self._lr_scheduler is not None: - self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"]) + # Does not work when loadning from previous checkpoint and trying to train beyond the last max epochs. + # if self._lr_scheduler is not None: + # self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"]) epoch = checkpoint["epoch"] @@ -257,7 +295,7 @@ class Model(ABC): ) shutil.copyfile(filepath, str(self.model_dir / "best.pt")) - def load_weights(self, network_fn: Type[nn.Module]) -> None: + def load_weights(self, network_fn: Type[nn.Module]) -> Tuple[Type[nn.Module], Dict]: """Load the network weights.""" logger.debug("Loading network with pretrained weights.") filename = glob(self.weights_filename)[0] @@ -267,12 +305,13 @@ class Model(ABC): ) # Loading state directory. state_dict = torch.load(filename, map_location=torch.device(self._device)) - self._network_args = state_dict["network_args"] + network_args = state_dict["network_args"] weights = state_dict["model_state"] # Initializes the network with trained weights. - self._network = network_fn(**self._network_args) - self._network.load_state_dict(weights) + network = network_fn(**self._network_args) + network.load_state_dict(weights) + return network, network_args def save_weights(self, path: Path) -> None: """Save the network weights.""" diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py index f1dabb7..0a0ab2d 100644 --- a/src/text_recognizer/models/character_model.py +++ b/src/text_recognizer/models/character_model.py @@ -6,10 +6,6 @@ import torch from torch import nn from torchvision.transforms import ToTensor -from text_recognizer.datasets.emnist_dataset import ( - _augment_emnist_mapping, - _load_emnist_essentials, -) from text_recognizer.models.base import Model @@ -20,7 +16,6 @@ class CharacterModel(Model): self, network_fn: Type[nn.Module], network_args: Optional[Dict] = None, - data_loader: Optional[Callable] = None, data_loader_args: Optional[Dict] = None, metrics: Optional[Dict] = None, criterion: Optional[Callable] = None, @@ -36,7 +31,6 @@ class CharacterModel(Model): super().__init__( network_fn, network_args, - data_loader, data_loader_args, metrics, criterion, @@ -47,16 +41,9 @@ class CharacterModel(Model): lr_scheduler_args, device, ) - if self.mapping is None: - self.load_mapping() self.tensor_transform = ToTensor() self.softmax = nn.Softmax(dim=0) - def load_mapping(self) -> None: - """Mapping between integers and classes.""" - essentials = _load_emnist_essentials() - self._mapping = _augment_emnist_mapping(dict(essentials["mapping"])) - def predict_on_image( self, image: Union[np.ndarray, torch.Tensor] ) -> Tuple[str, float]: @@ -86,6 +73,6 @@ class CharacterModel(Model): index = int(torch.argmax(prediction, dim=0)) confidence_of_prediction = prediction[index] - predicted_character = self._mapping[index] + predicted_character = self.mapper(index) return predicted_character, confidence_of_prediction diff --git a/src/text_recognizer/networks/ctc.py b/src/text_recognizer/networks/ctc.py new file mode 100644 index 0000000..00ad47e --- /dev/null +++ b/src/text_recognizer/networks/ctc.py @@ -0,0 +1,10 @@ +"""Decodes the CTC output.""" +# +# from typing import Tuple +# import torch +# +# +# def greedy_decoder( +# output, labels, label_length, blank_label, collapse_repeated=True +# ) -> Tuple[torch.Tensor, torch.Tensor]: +# pass diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py index 2839a0c..cbc58fc 100644 --- a/src/text_recognizer/networks/lenet.py +++ b/src/text_recognizer/networks/lenet.py @@ -1,24 +1,16 @@ """Defines the LeNet network.""" from typing import Callable, Dict, Optional, Tuple +from einops.layers.torch import Rearrange import torch from torch import nn -class Flatten(nn.Module): - """Flattens a tensor.""" - - def forward(self, x: int) -> torch.Tensor: - """Flattens a tensor for input to a nn.Linear layer.""" - return torch.flatten(x, start_dim=1) - - class LeNet(nn.Module): """LeNet network.""" def __init__( self, - input_size: Tuple[int, ...] = (1, 28, 28), channels: Tuple[int, ...] = (1, 32, 64), kernel_sizes: Tuple[int, ...] = (3, 3, 2), hidden_size: Tuple[int, ...] = (9216, 128), @@ -30,7 +22,6 @@ class LeNet(nn.Module): """The LeNet network. Args: - input_size (Tuple[int, ...]): The input shape of the network. Defaults to (1, 28, 28). channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64). kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (3, 3, 2). hidden_size (Tuple[int, ...]): Size of the flattend output form the convolutional layers. @@ -44,10 +35,9 @@ class LeNet(nn.Module): """ super().__init__() - self._input_size = input_size - if activation_fn is not None: - activation_fn = getattr(nn, activation_fn)(activation_fn_args) + activation_fn_args = activation_fn_args or {} + activation_fn = getattr(nn, activation_fn)(**activation_fn_args) else: activation_fn = nn.ReLU(inplace=True) @@ -66,7 +56,7 @@ class LeNet(nn.Module): activation_fn, nn.MaxPool2d(kernel_sizes[2]), nn.Dropout(p=dropout_rate), - Flatten(), + Rearrange("b c h w -> b (c h w)"), nn.Linear(in_features=hidden_size[0], out_features=hidden_size[1]), activation_fn, nn.Dropout(p=dropout_rate), @@ -77,6 +67,7 @@ class LeNet(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: """The feedforward.""" + # If batch dimenstion is missing, it needs to be added. if len(x.shape) == 3: x = x.unsqueeze(0) return self.layers(x) diff --git a/src/text_recognizer/networks/line_lstm_ctc.py b/src/text_recognizer/networks/line_lstm_ctc.py new file mode 100644 index 0000000..d704139 --- /dev/null +++ b/src/text_recognizer/networks/line_lstm_ctc.py @@ -0,0 +1,4 @@ +"""LSTM with CTC for handwritten text recognition within a line.""" + +import torch +from torch import nn diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/misc.py new file mode 100644 index 0000000..9440f9d --- /dev/null +++ b/src/text_recognizer/networks/misc.py @@ -0,0 +1,28 @@ +"""Miscellaneous neural network functionality.""" +from typing import Tuple + +from einops import rearrange +import torch +from torch.nn import Unfold + + +def sliding_window( + images: torch.Tensor, patch_size: Tuple[int, int], stride: Tuple[int, int] +) -> torch.Tensor: + """Creates patches of an image. + + Args: + images (torch.Tensor): A Torch tensor of a 4D image(s), i.e. (batch, channel, height, width). + patch_size (Tuple[int, int]): The size of the patches to generate, e.g. 28x28 for EMNIST. + stride (Tuple[int, int]): The stride of the sliding window. + + Returns: + torch.Tensor: A tensor with the shape (batch, patches, height, width). + + """ + unfold = Unfold(kernel_size=patch_size, stride=stride) + patches = unfold(images) + patches = rearrange( + patches, "b (h w) c -> b c h w", h=patch_size[0], w=patch_size[1] + ) + return patches diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py index d704d99..ac2c825 100644 --- a/src/text_recognizer/networks/mlp.py +++ b/src/text_recognizer/networks/mlp.py @@ -1,6 +1,7 @@ """Defines the MLP network.""" from typing import Callable, Dict, List, Optional, Union +from einops.layers.torch import Rearrange import torch from torch import nn @@ -34,7 +35,8 @@ class MLP(nn.Module): super().__init__() if activation_fn is not None: - activation_fn = getattr(nn, activation_fn)(activation_fn_args) + activation_fn_args = activation_fn_args or {} + activation_fn = getattr(nn, activation_fn)(**activation_fn_args) else: activation_fn = nn.ReLU(inplace=True) @@ -42,6 +44,7 @@ class MLP(nn.Module): hidden_size = [hidden_size] * num_layers self.layers = [ + Rearrange("b c h w -> b (c h w)"), nn.Linear(in_features=input_size, out_features=hidden_size[0]), activation_fn, ] @@ -63,7 +66,9 @@ class MLP(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: """The feedforward.""" - x = torch.flatten(x, start_dim=1) + # If batch dimenstion is missing, it needs to be added. + if len(x.shape) == 3: + x = x.unsqueeze(0) return self.layers(x) @property diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py new file mode 100644 index 0000000..23394b0 --- /dev/null +++ b/src/text_recognizer/networks/residual_network.py @@ -0,0 +1 @@ +"""Residual CNN.""" diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt Binary files differnew file mode 100644 index 0000000..81ef9be --- /dev/null +++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt Binary files differnew file mode 100644 index 0000000..49bd166 --- /dev/null +++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt diff --git a/src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt b/src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt Binary files differindex 46b1cb1..ed73c09 100644 --- a/src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt +++ b/src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt |