diff options
15 files changed, 96 insertions, 84 deletions
diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py index 2de7f09..95063bc 100644 --- a/src/text_recognizer/datasets/dataset.py +++ b/src/text_recognizer/datasets/dataset.py @@ -1,11 +1,12 @@ """Abstract dataset class.""" -from typing import Callable, Dict, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import torch from torch import Tensor from torch.utils import data from torchvision.transforms import ToTensor +import text_recognizer.datasets.transforms as transforms from text_recognizer.datasets.util import EmnistMapper @@ -16,8 +17,8 @@ class Dataset(data.Dataset): self, train: bool, subsample_fraction: float = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, + transform: Optional[List[Dict]] = None, + target_transform: Optional[List[Dict]] = None, init_token: Optional[str] = None, pad_token: Optional[str] = None, eos_token: Optional[str] = None, @@ -27,8 +28,8 @@ class Dataset(data.Dataset): Args: train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False. subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None. - transform (Optional[Callable]): Transform(s) for input data. Defaults to None. - target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None. + transform (Optional[List[Dict]]): List of Transform types and args for input data. Defaults to None. + target_transform (Optional[List[Dict]]): List of Transform types and args for output data. Defaults to None. init_token (Optional[str]): String representing the start of sequence token. Defaults to None. pad_token (Optional[str]): String representing the pad token. Defaults to None. eos_token (Optional[str]): String representing the end of sequence token. Defaults to None. @@ -53,14 +54,34 @@ class Dataset(data.Dataset): self.num_classes = self.mapper.num_classes # Set transforms. - self.transform = transform if transform is not None else ToTensor() - self.target_transform = ( - target_transform if target_transform is not None else torch.tensor - ) + self.transform = self._configure_transform(transform) + self.target_transform = self._configure_target_transform(target_transform) self._data = None self._targets = None + def _configure_transform(self, transform: List[Dict]) -> transforms.Compose: + transform_list = [] + if transform is not None: + for t in transform: + t_type = t["type"] + t_args = t["args"] or {} + transform_list.append(getattr(transforms, t_type)(**t_args)) + else: + transform_list.append(ToTensor()) + return transforms.Compose(transform_list) + + def _configure_target_transform( + self, target_transform: List[Dict] + ) -> transforms.Compose: + target_transform_list = [torch.tensor] + if target_transform is not None: + for t in target_transform: + t_type = t["type"] + t_args = t["args"] or {} + target_transform_list.append(getattr(transforms, t_type)(**t_args)) + return transforms.Compose(target_transform_list) + @property def data(self) -> Tensor: """The input data.""" diff --git a/src/text_recognizer/datasets/iam_lines_dataset.py b/src/text_recognizer/datasets/iam_lines_dataset.py index fdd2fe6..5ae142c 100644 --- a/src/text_recognizer/datasets/iam_lines_dataset.py +++ b/src/text_recognizer/datasets/iam_lines_dataset.py @@ -36,6 +36,8 @@ class IamLinesDataset(Dataset): pad_token: Optional[str] = None, eos_token: Optional[str] = None, ) -> None: + self.pad_token = "_" if pad_token is None else pad_token + super().__init__( train=train, subsample_fraction=subsample_fraction, diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py index d2df8b5..bf5e772 100644 --- a/src/text_recognizer/datasets/util.py +++ b/src/text_recognizer/datasets/util.py @@ -12,7 +12,8 @@ import cv2 from loguru import logger import numpy as np from PIL import Image -from torch.utils.data import DataLoader, Dataset +import torch +from torch import Tensor from torchvision.datasets import EMNIST from tqdm import tqdm @@ -20,7 +21,7 @@ DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data" ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json" -def save_emnist_essentials(emnsit_dataset: type = EMNIST) -> None: +def save_emnist_essentials(emnsit_dataset: EMNIST = EMNIST) -> None: """Extract and saves EMNIST essentials.""" labels = emnsit_dataset.classes labels.sort() @@ -56,21 +57,21 @@ class EmnistMapper: self.eos_token = eos_token self.essentials = self._load_emnist_essentials() - # Load dataset infromation. + # Load dataset information. self._mapping = dict(self.essentials["mapping"]) self._augment_emnist_mapping() self._inverse_mapping = {v: k for k, v in self.mapping.items()} self._num_classes = len(self.mapping) self._input_shape = self.essentials["input_shape"] - def __call__(self, token: Union[str, int, np.uint8]) -> Union[str, int]: + def __call__(self, token: Union[str, int, np.uint8, Tensor]) -> Union[str, int]: """Maps the token to emnist character or character index. If the token is an integer (index), the method will return the Emnist character corresponding to that index. If the token is a str (Emnist character), the method will return the corresponding index for that character. Args: - token (Union[str, int, np.uint8]): Eihter a string or index (integer). + token (Union[str, int, np.uint8, Tensor]): Either a string or index (integer). Returns: Union[str, int]: The mapping result. @@ -79,9 +80,11 @@ class EmnistMapper: KeyError: If the index or string does not exist in the mapping. """ - if (isinstance(token, np.uint8) or isinstance(token, int)) and int( - token - ) in self.mapping: + if ( + (isinstance(token, np.uint8) or isinstance(token, int)) + or torch.is_tensor(token) + and int(token) in self.mapping + ): return self.mapping[int(token)] elif isinstance(token, str) and token in self._inverse_mapping: return self._inverse_mapping[token] diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index a945b41..d394b4c 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -15,8 +15,9 @@ from torch import Tensor from torch.optim.swa_utils import AveragedModel, SWALR from torch.utils.data import DataLoader, Dataset, random_split from torchsummary import summary -from torchvision.transforms import Compose +from text_recognizer import datasets +from text_recognizer import networks from text_recognizer.datasets import EmnistMapper WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights" @@ -27,8 +28,8 @@ class Model(ABC): def __init__( self, - network_fn: Type[nn.Module], - dataset: Type[Dataset], + network_fn: str, + dataset: str, network_args: Optional[Dict] = None, dataset_args: Optional[Dict] = None, metrics: Optional[Dict] = None, @@ -44,8 +45,8 @@ class Model(ABC): """Base class, to be inherited by model for specific type of data. Args: - network_fn (Type[nn.Module]): The PyTorch network. - dataset (Type[Dataset]): A dataset class. + network_fn (str): The name of network. + dataset (str): The name dataset class. network_args (Optional[Dict]): Arguments for the network. Defaults to None. dataset_args (Optional[Dict]): Arguments for the dataset. metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None. @@ -62,13 +63,15 @@ class Model(ABC): device (Optional[str]): Name of the device to train on. Defaults to None. """ + self._name = f"{self.__class__.__name__}_{dataset}_{network_fn}" # Has to be set in subclass. self._mapper = None # Placeholder. self._input_shape = None - self.dataset = dataset + self.dataset_name = dataset + self.dataset = None self.dataset_args = dataset_args # Placeholders for datasets. @@ -92,10 +95,6 @@ class Model(ABC): # Flag for stopping training. self.stop_training = False - self._name = ( - f"{self.__class__.__name__}_{dataset.__name__}_{network_fn.__name__}" - ) - self._metrics = metrics if metrics is not None else None # Set the device. @@ -132,38 +131,12 @@ class Model(ABC): # Set this flag to true to prevent the model from configuring again. self.is_configured = True - def _configure_transforms(self) -> None: - # Load transforms. - transforms_module = importlib.import_module( - "text_recognizer.datasets.transforms" - ) - if ( - "transform" in self.dataset_args["args"] - and self.dataset_args["args"]["transform"] is not None - ): - transform_ = [] - for t in self.dataset_args["args"]["transform"]: - args = t["args"] or {} - transform_.append(getattr(transforms_module, t["type"])(**args)) - self.dataset_args["args"]["transform"] = Compose(transform_) - - if ( - "target_transform" in self.dataset_args["args"] - and self.dataset_args["args"]["target_transform"] is not None - ): - target_transform_ = [ - torch.tensor, - ] - for t in self.dataset_args["args"]["target_transform"]: - args = t["args"] or {} - target_transform_.append(getattr(transforms_module, t["type"])(**args)) - self.dataset_args["args"]["target_transform"] = Compose(target_transform_) - def prepare_data(self) -> None: """Prepare data for training.""" # TODO add downloading. if not self.data_prepared: - self._configure_transforms() + # Load dataset module. + self.dataset = getattr(datasets, self.dataset_name) # Load train dataset. train_dataset = self.dataset(train=True, **self.dataset_args["args"]) @@ -222,6 +195,8 @@ class Model(ABC): def _configure_network(self, network_fn: Type[nn.Module]) -> None: """Loads the network.""" # If no network arguments are given, load pretrained weights if they exist. + # Load network module. + network_fn = getattr(networks, network_fn) if self._network_args is None: self.load_weights(network_fn) else: diff --git a/src/text_recognizer/tests/support/create_emnist_lines_support_files.py b/src/text_recognizer/tests/support/create_emnist_lines_support_files.py index 4496e40..9abe143 100644 --- a/src/text_recognizer/tests/support/create_emnist_lines_support_files.py +++ b/src/text_recognizer/tests/support/create_emnist_lines_support_files.py @@ -1,4 +1,6 @@ """Module for creating EMNIST Lines test support files.""" +# flake8: noqa: S106 + from pathlib import Path import shutil @@ -17,23 +19,31 @@ def create_emnist_lines_support_files() -> None: SUPPORT_DIRNAME.mkdir() # TODO: maybe have to add args to dataset. - dataset = EmnistLinesDataset() + dataset = EmnistLinesDataset( + init_token="<sos>", + pad_token="_", + eos_token="<eos>", + transform=[{"type": "ToTensor", "args": {}}], + target_transform=[ + { + "type": "AddTokens", + "args": {"init_token": "<sos>", "pad_token": "_", "eos_token": "<eos>"}, + } + ], + ) # nosec: S106 dataset.load_or_generate_data() - for index in [0, 1, 3]: + for index in [5, 7, 9]: image, target = dataset[index] + if len(image.shape) == 3: + image = image.squeeze(0) print(image.sum(), image.dtype) - label = ( - "".join( - dataset.mapper[label] - for label in np.argmax(target[1:], dim=-1).flatten() - ) - .stip() - .strip(dataset.mapper.pad_token) + label = "".join(dataset.mapper(label) for label in target[1:]).strip( + dataset.mapper.pad_token ) - print(label) + image = image.numpy() util.write_image(image, str(SUPPORT_DIRNAME / f"{label}.png")) diff --git a/src/text_recognizer/tests/support/create_iam_lines_support_files.py b/src/text_recognizer/tests/support/create_iam_lines_support_files.py index bb568ee..50f9e3d 100644 --- a/src/text_recognizer/tests/support/create_iam_lines_support_files.py +++ b/src/text_recognizer/tests/support/create_iam_lines_support_files.py @@ -1,4 +1,5 @@ """Module for creating IAM Lines test support files.""" +# flake8: noqa from pathlib import Path import shutil @@ -17,23 +18,31 @@ def create_emnist_lines_support_files() -> None: SUPPORT_DIRNAME.mkdir() # TODO: maybe have to add args to dataset. - dataset = IamLinesDataset() + dataset = IamLinesDataset( + init_token="<sos>", + pad_token="_", + eos_token="<eos>", + transform=[{"type": "ToTensor", "args": {}}], + target_transform=[ + { + "type": "AddTokens", + "args": {"init_token": "<sos>", "pad_token": "_", "eos_token": "<eos>"}, + } + ], + ) dataset.load_or_generate_data() for index in [0, 1, 3]: image, target = dataset[index] + if len(image.shape) == 3: + image = image.squeeze(0) print(image.sum(), image.dtype) - label = ( - "".join( - dataset.mapper[label] - for label in np.argmax(target[1:], dim=-1).flatten() - ) - .stip() - .strip(dataset.mapper.pad_token) + label = "".join(dataset.mapper(label) for label in target[1:]).strip( + dataset.mapper.pad_token ) - print(label) + image = image.numpy() util.write_image(image, str(SUPPORT_DIRNAME / f"{label}.png")) diff --git a/src/text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png b/src/text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png Binary files differnew file mode 100644 index 0000000..b7d0618 --- /dev/null +++ b/src/text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png diff --git a/src/text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png b/src/text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png Binary files differnew file mode 100644 index 0000000..14a8cf3 --- /dev/null +++ b/src/text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png diff --git a/src/text_recognizer/tests/support/emnist_lines/they<eos>.png b/src/text_recognizer/tests/support/emnist_lines/they<eos>.png Binary files differnew file mode 100644 index 0000000..7f05951 --- /dev/null +++ b/src/text_recognizer/tests/support/emnist_lines/they<eos>.png diff --git a/src/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png b/src/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png Binary files differnew file mode 100644 index 0000000..6eeb642 --- /dev/null +++ b/src/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png diff --git a/src/text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png b/src/text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png Binary files differnew file mode 100644 index 0000000..4974cf8 --- /dev/null +++ b/src/text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png diff --git a/src/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png b/src/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png Binary files differnew file mode 100644 index 0000000..a731245 --- /dev/null +++ b/src/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png diff --git a/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt b/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt deleted file mode 100644 index 04e1952..0000000 --- a/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:b7fa37f24951732e68bf2191fd58e9c332848d5800bf47dc1aa1a9c32c63485a -size 61946486 diff --git a/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt b/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt deleted file mode 100644 index 50a6a20..0000000 --- a/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:f2e458cf749526ffa876709fc76084d8099a490d9ff24f9be36b8b04c037f073 -size 3457858 diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py index e6ae84c..55a9572 100644 --- a/src/training/run_experiment.py +++ b/src/training/run_experiment.py @@ -74,8 +74,7 @@ def _load_modules_and_arguments(experiment_config: Dict,) -> Tuple[Callable, Dic """Loads all modules and arguments.""" # Load the dataset module. dataset_args = experiment_config.get("dataset", {}) - datasets_module = importlib.import_module("text_recognizer.datasets") - dataset_ = getattr(datasets_module, dataset_args["type"]) + dataset_ = dataset_args["type"] # Import the model module and model arguments. models_module = importlib.import_module("text_recognizer.models") @@ -92,8 +91,7 @@ def _load_modules_and_arguments(experiment_config: Dict,) -> Tuple[Callable, Dic ) # Import network module and arguments. - network_module = importlib.import_module("text_recognizer.networks") - network_fn_ = getattr(network_module, experiment_config["network"]["type"]) + network_fn_ = experiment_config["network"]["type"] network_args = experiment_config["network"].get("args", {}) # Criterion |