From 75909723fa2b1f6245d5c5422e4f2e88b8a26052 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Sun, 15 Nov 2020 17:40:44 +0100 Subject: Able to generate support files for lines datasets. --- src/text_recognizer/datasets/dataset.py | 39 ++++++++++++---- src/text_recognizer/datasets/iam_lines_dataset.py | 2 + src/text_recognizer/datasets/util.py | 19 ++++---- src/text_recognizer/models/base.py | 51 ++++++--------------- .../support/create_emnist_lines_support_files.py | 30 ++++++++---- .../support/create_iam_lines_support_files.py | 27 +++++++---- .../tests/support/emnist_lines/Knox Ky.png | Bin 0 -> 2301 bytes .../emnist_lines/ancillary beliefs and.png | Bin 0 -> 5424 bytes .../tests/support/emnist_lines/they.png | Bin 0 -> 1391 bytes .../He rose from his breakfast-nook bench.png | Bin 0 -> 5170 bytes .../and came into the livingroom, where.png | Bin 0 -> 3617 bytes .../his entrance. He came, almost falling.png | Bin 0 -> 3923 bytes ...istLinesDataset_LineRecurrentNetwork_weights.pt | 3 -- ...IamLinesDataset_LineRecurrentNetwork_weights.pt | 3 -- src/training/run_experiment.py | 6 +-- 15 files changed, 96 insertions(+), 84 deletions(-) create mode 100644 src/text_recognizer/tests/support/emnist_lines/Knox Ky.png create mode 100644 src/text_recognizer/tests/support/emnist_lines/ancillary beliefs and.png create mode 100644 src/text_recognizer/tests/support/emnist_lines/they.png create mode 100644 src/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench.png create mode 100644 src/text_recognizer/tests/support/iam_lines/and came into the livingroom, where.png create mode 100644 src/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling.png delete mode 100644 src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt delete mode 100644 src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py index 2de7f09..95063bc 100644 --- a/src/text_recognizer/datasets/dataset.py +++ b/src/text_recognizer/datasets/dataset.py @@ -1,11 +1,12 @@ """Abstract dataset class.""" -from typing import Callable, Dict, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import torch from torch import Tensor from torch.utils import data from torchvision.transforms import ToTensor +import text_recognizer.datasets.transforms as transforms from text_recognizer.datasets.util import EmnistMapper @@ -16,8 +17,8 @@ class Dataset(data.Dataset): self, train: bool, subsample_fraction: float = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, + transform: Optional[List[Dict]] = None, + target_transform: Optional[List[Dict]] = None, init_token: Optional[str] = None, pad_token: Optional[str] = None, eos_token: Optional[str] = None, @@ -27,8 +28,8 @@ class Dataset(data.Dataset): Args: train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False. subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None. - transform (Optional[Callable]): Transform(s) for input data. Defaults to None. - target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None. + transform (Optional[List[Dict]]): List of Transform types and args for input data. Defaults to None. + target_transform (Optional[List[Dict]]): List of Transform types and args for output data. Defaults to None. init_token (Optional[str]): String representing the start of sequence token. Defaults to None. pad_token (Optional[str]): String representing the pad token. Defaults to None. eos_token (Optional[str]): String representing the end of sequence token. Defaults to None. @@ -53,14 +54,34 @@ class Dataset(data.Dataset): self.num_classes = self.mapper.num_classes # Set transforms. - self.transform = transform if transform is not None else ToTensor() - self.target_transform = ( - target_transform if target_transform is not None else torch.tensor - ) + self.transform = self._configure_transform(transform) + self.target_transform = self._configure_target_transform(target_transform) self._data = None self._targets = None + def _configure_transform(self, transform: List[Dict]) -> transforms.Compose: + transform_list = [] + if transform is not None: + for t in transform: + t_type = t["type"] + t_args = t["args"] or {} + transform_list.append(getattr(transforms, t_type)(**t_args)) + else: + transform_list.append(ToTensor()) + return transforms.Compose(transform_list) + + def _configure_target_transform( + self, target_transform: List[Dict] + ) -> transforms.Compose: + target_transform_list = [torch.tensor] + if target_transform is not None: + for t in target_transform: + t_type = t["type"] + t_args = t["args"] or {} + target_transform_list.append(getattr(transforms, t_type)(**t_args)) + return transforms.Compose(target_transform_list) + @property def data(self) -> Tensor: """The input data.""" diff --git a/src/text_recognizer/datasets/iam_lines_dataset.py b/src/text_recognizer/datasets/iam_lines_dataset.py index fdd2fe6..5ae142c 100644 --- a/src/text_recognizer/datasets/iam_lines_dataset.py +++ b/src/text_recognizer/datasets/iam_lines_dataset.py @@ -36,6 +36,8 @@ class IamLinesDataset(Dataset): pad_token: Optional[str] = None, eos_token: Optional[str] = None, ) -> None: + self.pad_token = "_" if pad_token is None else pad_token + super().__init__( train=train, subsample_fraction=subsample_fraction, diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py index d2df8b5..bf5e772 100644 --- a/src/text_recognizer/datasets/util.py +++ b/src/text_recognizer/datasets/util.py @@ -12,7 +12,8 @@ import cv2 from loguru import logger import numpy as np from PIL import Image -from torch.utils.data import DataLoader, Dataset +import torch +from torch import Tensor from torchvision.datasets import EMNIST from tqdm import tqdm @@ -20,7 +21,7 @@ DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data" ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json" -def save_emnist_essentials(emnsit_dataset: type = EMNIST) -> None: +def save_emnist_essentials(emnsit_dataset: EMNIST = EMNIST) -> None: """Extract and saves EMNIST essentials.""" labels = emnsit_dataset.classes labels.sort() @@ -56,21 +57,21 @@ class EmnistMapper: self.eos_token = eos_token self.essentials = self._load_emnist_essentials() - # Load dataset infromation. + # Load dataset information. self._mapping = dict(self.essentials["mapping"]) self._augment_emnist_mapping() self._inverse_mapping = {v: k for k, v in self.mapping.items()} self._num_classes = len(self.mapping) self._input_shape = self.essentials["input_shape"] - def __call__(self, token: Union[str, int, np.uint8]) -> Union[str, int]: + def __call__(self, token: Union[str, int, np.uint8, Tensor]) -> Union[str, int]: """Maps the token to emnist character or character index. If the token is an integer (index), the method will return the Emnist character corresponding to that index. If the token is a str (Emnist character), the method will return the corresponding index for that character. Args: - token (Union[str, int, np.uint8]): Eihter a string or index (integer). + token (Union[str, int, np.uint8, Tensor]): Either a string or index (integer). Returns: Union[str, int]: The mapping result. @@ -79,9 +80,11 @@ class EmnistMapper: KeyError: If the index or string does not exist in the mapping. """ - if (isinstance(token, np.uint8) or isinstance(token, int)) and int( - token - ) in self.mapping: + if ( + (isinstance(token, np.uint8) or isinstance(token, int)) + or torch.is_tensor(token) + and int(token) in self.mapping + ): return self.mapping[int(token)] elif isinstance(token, str) and token in self._inverse_mapping: return self._inverse_mapping[token] diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index a945b41..d394b4c 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -15,8 +15,9 @@ from torch import Tensor from torch.optim.swa_utils import AveragedModel, SWALR from torch.utils.data import DataLoader, Dataset, random_split from torchsummary import summary -from torchvision.transforms import Compose +from text_recognizer import datasets +from text_recognizer import networks from text_recognizer.datasets import EmnistMapper WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights" @@ -27,8 +28,8 @@ class Model(ABC): def __init__( self, - network_fn: Type[nn.Module], - dataset: Type[Dataset], + network_fn: str, + dataset: str, network_args: Optional[Dict] = None, dataset_args: Optional[Dict] = None, metrics: Optional[Dict] = None, @@ -44,8 +45,8 @@ class Model(ABC): """Base class, to be inherited by model for specific type of data. Args: - network_fn (Type[nn.Module]): The PyTorch network. - dataset (Type[Dataset]): A dataset class. + network_fn (str): The name of network. + dataset (str): The name dataset class. network_args (Optional[Dict]): Arguments for the network. Defaults to None. dataset_args (Optional[Dict]): Arguments for the dataset. metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None. @@ -62,13 +63,15 @@ class Model(ABC): device (Optional[str]): Name of the device to train on. Defaults to None. """ + self._name = f"{self.__class__.__name__}_{dataset}_{network_fn}" # Has to be set in subclass. self._mapper = None # Placeholder. self._input_shape = None - self.dataset = dataset + self.dataset_name = dataset + self.dataset = None self.dataset_args = dataset_args # Placeholders for datasets. @@ -92,10 +95,6 @@ class Model(ABC): # Flag for stopping training. self.stop_training = False - self._name = ( - f"{self.__class__.__name__}_{dataset.__name__}_{network_fn.__name__}" - ) - self._metrics = metrics if metrics is not None else None # Set the device. @@ -132,38 +131,12 @@ class Model(ABC): # Set this flag to true to prevent the model from configuring again. self.is_configured = True - def _configure_transforms(self) -> None: - # Load transforms. - transforms_module = importlib.import_module( - "text_recognizer.datasets.transforms" - ) - if ( - "transform" in self.dataset_args["args"] - and self.dataset_args["args"]["transform"] is not None - ): - transform_ = [] - for t in self.dataset_args["args"]["transform"]: - args = t["args"] or {} - transform_.append(getattr(transforms_module, t["type"])(**args)) - self.dataset_args["args"]["transform"] = Compose(transform_) - - if ( - "target_transform" in self.dataset_args["args"] - and self.dataset_args["args"]["target_transform"] is not None - ): - target_transform_ = [ - torch.tensor, - ] - for t in self.dataset_args["args"]["target_transform"]: - args = t["args"] or {} - target_transform_.append(getattr(transforms_module, t["type"])(**args)) - self.dataset_args["args"]["target_transform"] = Compose(target_transform_) - def prepare_data(self) -> None: """Prepare data for training.""" # TODO add downloading. if not self.data_prepared: - self._configure_transforms() + # Load dataset module. + self.dataset = getattr(datasets, self.dataset_name) # Load train dataset. train_dataset = self.dataset(train=True, **self.dataset_args["args"]) @@ -222,6 +195,8 @@ class Model(ABC): def _configure_network(self, network_fn: Type[nn.Module]) -> None: """Loads the network.""" # If no network arguments are given, load pretrained weights if they exist. + # Load network module. + network_fn = getattr(networks, network_fn) if self._network_args is None: self.load_weights(network_fn) else: diff --git a/src/text_recognizer/tests/support/create_emnist_lines_support_files.py b/src/text_recognizer/tests/support/create_emnist_lines_support_files.py index 4496e40..9abe143 100644 --- a/src/text_recognizer/tests/support/create_emnist_lines_support_files.py +++ b/src/text_recognizer/tests/support/create_emnist_lines_support_files.py @@ -1,4 +1,6 @@ """Module for creating EMNIST Lines test support files.""" +# flake8: noqa: S106 + from pathlib import Path import shutil @@ -17,23 +19,31 @@ def create_emnist_lines_support_files() -> None: SUPPORT_DIRNAME.mkdir() # TODO: maybe have to add args to dataset. - dataset = EmnistLinesDataset() + dataset = EmnistLinesDataset( + init_token="", + pad_token="_", + eos_token="", + transform=[{"type": "ToTensor", "args": {}}], + target_transform=[ + { + "type": "AddTokens", + "args": {"init_token": "", "pad_token": "_", "eos_token": ""}, + } + ], + ) # nosec: S106 dataset.load_or_generate_data() - for index in [0, 1, 3]: + for index in [5, 7, 9]: image, target = dataset[index] + if len(image.shape) == 3: + image = image.squeeze(0) print(image.sum(), image.dtype) - label = ( - "".join( - dataset.mapper[label] - for label in np.argmax(target[1:], dim=-1).flatten() - ) - .stip() - .strip(dataset.mapper.pad_token) + label = "".join(dataset.mapper(label) for label in target[1:]).strip( + dataset.mapper.pad_token ) - print(label) + image = image.numpy() util.write_image(image, str(SUPPORT_DIRNAME / f"{label}.png")) diff --git a/src/text_recognizer/tests/support/create_iam_lines_support_files.py b/src/text_recognizer/tests/support/create_iam_lines_support_files.py index bb568ee..50f9e3d 100644 --- a/src/text_recognizer/tests/support/create_iam_lines_support_files.py +++ b/src/text_recognizer/tests/support/create_iam_lines_support_files.py @@ -1,4 +1,5 @@ """Module for creating IAM Lines test support files.""" +# flake8: noqa from pathlib import Path import shutil @@ -17,23 +18,31 @@ def create_emnist_lines_support_files() -> None: SUPPORT_DIRNAME.mkdir() # TODO: maybe have to add args to dataset. - dataset = IamLinesDataset() + dataset = IamLinesDataset( + init_token="", + pad_token="_", + eos_token="", + transform=[{"type": "ToTensor", "args": {}}], + target_transform=[ + { + "type": "AddTokens", + "args": {"init_token": "", "pad_token": "_", "eos_token": ""}, + } + ], + ) dataset.load_or_generate_data() for index in [0, 1, 3]: image, target = dataset[index] + if len(image.shape) == 3: + image = image.squeeze(0) print(image.sum(), image.dtype) - label = ( - "".join( - dataset.mapper[label] - for label in np.argmax(target[1:], dim=-1).flatten() - ) - .stip() - .strip(dataset.mapper.pad_token) + label = "".join(dataset.mapper(label) for label in target[1:]).strip( + dataset.mapper.pad_token ) - print(label) + image = image.numpy() util.write_image(image, str(SUPPORT_DIRNAME / f"{label}.png")) diff --git a/src/text_recognizer/tests/support/emnist_lines/Knox Ky.png b/src/text_recognizer/tests/support/emnist_lines/Knox Ky.png new file mode 100644 index 0000000..b7d0618 Binary files /dev/null and b/src/text_recognizer/tests/support/emnist_lines/Knox Ky.png differ diff --git a/src/text_recognizer/tests/support/emnist_lines/ancillary beliefs and.png b/src/text_recognizer/tests/support/emnist_lines/ancillary beliefs and.png new file mode 100644 index 0000000..14a8cf3 Binary files /dev/null and b/src/text_recognizer/tests/support/emnist_lines/ancillary beliefs and.png differ diff --git a/src/text_recognizer/tests/support/emnist_lines/they.png b/src/text_recognizer/tests/support/emnist_lines/they.png new file mode 100644 index 0000000..7f05951 Binary files /dev/null and b/src/text_recognizer/tests/support/emnist_lines/they.png differ diff --git a/src/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench.png b/src/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench.png new file mode 100644 index 0000000..6eeb642 Binary files /dev/null and b/src/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench.png differ diff --git a/src/text_recognizer/tests/support/iam_lines/and came into the livingroom, where.png b/src/text_recognizer/tests/support/iam_lines/and came into the livingroom, where.png new file mode 100644 index 0000000..4974cf8 Binary files /dev/null and b/src/text_recognizer/tests/support/iam_lines/and came into the livingroom, where.png differ diff --git a/src/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling.png b/src/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling.png new file mode 100644 index 0000000..a731245 Binary files /dev/null and b/src/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling.png differ diff --git a/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt b/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt deleted file mode 100644 index 04e1952..0000000 --- a/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:b7fa37f24951732e68bf2191fd58e9c332848d5800bf47dc1aa1a9c32c63485a -size 61946486 diff --git a/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt b/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt deleted file mode 100644 index 50a6a20..0000000 --- a/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:f2e458cf749526ffa876709fc76084d8099a490d9ff24f9be36b8b04c037f073 -size 3457858 diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py index e6ae84c..55a9572 100644 --- a/src/training/run_experiment.py +++ b/src/training/run_experiment.py @@ -74,8 +74,7 @@ def _load_modules_and_arguments(experiment_config: Dict,) -> Tuple[Callable, Dic """Loads all modules and arguments.""" # Load the dataset module. dataset_args = experiment_config.get("dataset", {}) - datasets_module = importlib.import_module("text_recognizer.datasets") - dataset_ = getattr(datasets_module, dataset_args["type"]) + dataset_ = dataset_args["type"] # Import the model module and model arguments. models_module = importlib.import_module("text_recognizer.models") @@ -92,8 +91,7 @@ def _load_modules_and_arguments(experiment_config: Dict,) -> Tuple[Callable, Dic ) # Import network module and arguments. - network_module = importlib.import_module("text_recognizer.networks") - network_fn_ = getattr(network_module, experiment_config["network"]["type"]) + network_fn_ = experiment_config["network"]["type"] network_args = experiment_config["network"].get("args", {}) # Criterion -- cgit v1.2.3-70-g09d2