Diffstat (limited to 'src/text_recognizer')
| -rw-r--r-- | src/text_recognizer/datasets/__init__.py | 13 |
| -rw-r--r-- | src/text_recognizer/datasets/emnist_dataset.py | 275 |
| -rw-r--r-- | src/text_recognizer/datasets/emnist_lines_dataset.py | 129 |
| -rw-r--r-- | src/text_recognizer/datasets/util.py | 60 |
| -rw-r--r-- | src/text_recognizer/models/base.py | 139 |
| -rw-r--r-- | src/text_recognizer/models/character_model.py | 15 |
| -rw-r--r-- | src/text_recognizer/networks/ctc.py | 10 |
| -rw-r--r-- | src/text_recognizer/networks/lenet.py | 19 |
| -rw-r--r-- | src/text_recognizer/networks/line_lstm_ctc.py | 4 |
| -rw-r--r-- | src/text_recognizer/networks/misc.py | 28 |
| -rw-r--r-- | src/text_recognizer/networks/mlp.py | 9 |
| -rw-r--r-- | src/text_recognizer/networks/residual_network.py | 1 |
| -rw-r--r-- | src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt | bin | 0 -> 14485310 bytes |
| -rw-r--r-- | src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt | bin | 0 -> 1704174 bytes |
| -rw-r--r-- | src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt | bin | 14485305 -> 14485342 bytes |
15 files changed, 346 insertions, 356 deletions
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py
index 1b4cc59..05f74f6 100644
--- a/src/text_recognizer/datasets/__init__.py
+++ b/src/text_recognizer/datasets/__init__.py
@@ -1,29 +1,24 @@
 """Dataset modules."""
 from .emnist_dataset import (
-    _augment_emnist_mapping,
-    _load_emnist_essentials,
     DATA_DIRNAME,
-    EmnistDataLoaders,
     EmnistDataset,
+    EmnistMapper,
     ESSENTIALS_FILENAME,
 )
 from .emnist_lines_dataset import (
     construct_image_from_string,
-    EmnistLinesDataLoaders,
     EmnistLinesDataset,
     get_samples_by_character,
 )
-from .util import Transpose
+from .util import fetch_data_loaders, Transpose
 
 __all__ = [
-    "_augment_emnist_mapping",
-    "_load_emnist_essentials",
     "construct_image_from_string",
     "DATA_DIRNAME",
     "EmnistDataset",
-    "EmnistDataLoaders",
-    "EmnistLinesDataLoaders",
+    "EmnistMapper",
     "EmnistLinesDataset",
+    "fetch_data_loaders",
    "get_samples_by_character",
     "Transpose",
 ]
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index f3d65ee..96f84e5 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -39,45 +39,101 @@ def download_emnist() -> None:
     save_emnist_essentials(dataset)
 
 
-def _load_emnist_essentials() -> Dict:
-    """Load the EMNIST mapping."""
-    with open(str(ESSENTIALS_FILENAME)) as f:
-        essentials = json.load(f)
-    return essentials
-
-
-def _augment_emnist_mapping(mapping: Dict) -> Dict:
-    """Augment the mapping with extra symbols."""
-    # Extra symbols in IAM dataset
-    extra_symbols = [
-        " ",
-        "!",
-        '"',
-        "#",
-        "&",
-        "'",
-        "(",
-        ")",
-        "*",
-        "+",
-        ",",
-        "-",
-        ".",
-        "/",
-        ":",
-        ";",
-        "?",
-    ]
-
-    # padding symbol
-    extra_symbols.append("_")
-
-    max_key = max(mapping.keys())
-    extra_mapping = {}
-    for i, symbol in enumerate(extra_symbols):
-        extra_mapping[max_key + 1 + i] = symbol
-
-    return {**mapping, **extra_mapping}
+class EmnistMapper:
+    """Mapper between network output and Emnist characters."""
+
+    def __init__(self) -> None:
+        """Loads the emnist essentials file with the mapping and input shape."""
+        self.essentials = self._load_emnist_essentials()
+        # Load dataset information.
+        self._mapping = self._augment_emnist_mapping(dict(self.essentials["mapping"]))
+        self._inverse_mapping = {v: k for k, v in self.mapping.items()}
+        self._num_classes = len(self.mapping)
+        self._input_shape = self.essentials["input_shape"]
+
+    def __call__(self, token: Union[str, int, np.uint8]) -> Union[str, int]:
+        """Maps the token to an Emnist character or character index.
+
+        If the token is an integer (index), the method will return the Emnist character corresponding to that index.
+        If the token is a str (Emnist character), the method will return the corresponding index for that character.
+
+        Args:
+            token (Union[str, int, np.uint8]): Either a string or index (integer).
+
+        Returns:
+            Union[str, int]: The mapping result.
+
+        Raises:
+            KeyError: If the index or string does not exist in the mapping.
+ +        """ +        if (isinstance(token, np.uint8) or isinstance(token, int)) and int( +            token +        ) in self.mapping: +            return self.mapping[int(token)] +        elif isinstance(token, str) and token in self._inverse_mapping: +            return self._inverse_mapping[token] +        else: +            raise KeyError(f"Token {token} does not exist in the mappings.") + +    @property +    def mapping(self) -> Dict: +        """Returns the mapping between index and character.""" +        return self._mapping + +    @property +    def inverse_mapping(self) -> Dict: +        """Returns the mapping between character and index.""" +        return self._inverse_mapping + +    @property +    def num_classes(self) -> int: +        """Returns the number of classes in the dataset.""" +        return self._num_classes + +    @property +    def input_shape(self) -> List[int]: +        """Returns the input shape of the Emnist characters.""" +        return self._input_shape + +    def _load_emnist_essentials(self) -> Dict: +        """Load the EMNIST mapping.""" +        with open(str(ESSENTIALS_FILENAME)) as f: +            essentials = json.load(f) +        return essentials + +    def _augment_emnist_mapping(self, mapping: Dict) -> Dict: +        """Augment the mapping with extra symbols.""" +        # Extra symbols in IAM dataset +        extra_symbols = [ +            " ", +            "!", +            '"', +            "#", +            "&", +            "'", +            "(", +            ")", +            "*", +            "+", +            ",", +            "-", +            ".", +            "/", +            ":", +            ";", +            "?", +        ] + +        # padding symbol +        extra_symbols.append("_") + +        max_key = max(mapping.keys()) +        extra_mapping = {} +        for i, symbol in enumerate(extra_symbols): +            extra_mapping[max_key + 1 + i] = symbol + +        return {**mapping, **extra_mapping}  class EmnistDataset(Dataset): @@ -110,10 +166,12 @@ class EmnistDataset(Dataset):          self.train = train          self.sample_to_balance = sample_to_balance +          if subsample_fraction is not None:              if not 0.0 < subsample_fraction < 1.0:                  raise ValueError("The subsample fraction must be in (0, 1).")          self.subsample_fraction = subsample_fraction +          self.transform = transform          if self.transform is None:              self.transform = Compose([Transpose(), ToTensor()]) @@ -121,17 +179,22 @@ class EmnistDataset(Dataset):          self.target_transform = target_transform          self.seed = seed -        # Load dataset infromation. -        essentials = _load_emnist_essentials() -        self.mapping = _augment_emnist_mapping(dict(essentials["mapping"])) -        self.inverse_mapping = {v: k for k, v in self.mapping.items()} -        self.num_classes = len(self.mapping) -        self.input_shape = essentials["input_shape"] +        self._mapper = EmnistMapper() +        self.input_shape = self._mapper.input_shape +        self.num_classes = self._mapper.num_classes          # Placeholders          self.data = None          self.targets = None +        # Load dataset. 
+        self.load_emnist_dataset() + +    @property +    def mapper(self) -> EmnistMapper: +        """Returns the EmnistMapper.""" +        return self._mapper +      def __len__(self) -> int:          """Returns the length of the dataset."""          return len(self.data) @@ -162,13 +225,18 @@ class EmnistDataset(Dataset):          return data, targets +    @property +    def __name__(self) -> str: +        """Returns the name of the dataset.""" +        return "EmnistDataset" +      def __repr__(self) -> str:          """Returns information about the dataset."""          return (              "EMNIST Dataset\n"              f"Num classes: {self.num_classes}\n" -            f"Mapping: {self.mapping}\n"              f"Input shape: {self.input_shape}\n" +            f"Mapping: {self.mapper.mapping}\n"          )      def _sample_to_balance(self) -> None: @@ -217,118 +285,3 @@ class EmnistDataset(Dataset):          if self.subsample_fraction is not None:              self._subsample() - - -class EmnistDataLoaders: -    """Class for Emnist DataLoaders.""" - -    def __init__( -        self, -        splits: List[str], -        sample_to_balance: bool = False, -        subsample_fraction: float = None, -        transform: Optional[Callable] = None, -        target_transform: Optional[Callable] = None, -        batch_size: int = 128, -        shuffle: bool = False, -        num_workers: int = 0, -        cuda: bool = True, -        seed: int = 4711, -    ) -> None: -        """Fetches DataLoaders for given split(s). - -        Args: -            splits (List[str]): One or both of the dataset splits "train" and "val". -            sample_to_balance (bool): If true, resamples the unbalanced if the split "byclass" is selected. -                Defaults to False. -            subsample_fraction (float): The fraction of the dataset will be loaded. If None or 0 the entire -                dataset will be loaded. -            transform (Optional[Callable]):  A function/transform that takes in an PIL image and returns a -                transformed version. E.g, transforms.RandomCrop. Defaults to None. -            target_transform (Optional[Callable]): A function/transform that takes in the target and -                transforms it. Defaults to None. -            batch_size (int): How many samples per batch to load. Defaults to 128. -            shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False. -            num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be -                loaded in the main process. Defaults to 0. -            cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning -                them. Defaults to True. -            seed (int): Seed for sampling. - -        Raises: -            ValueError: If subsample_fraction is not None and outside the range (0, 1). 
- -        """ -        self.splits = splits - -        if subsample_fraction is not None: -            if not 0.0 < subsample_fraction < 1.0: -                raise ValueError("The subsample fraction must be in (0, 1).") - -        self.dataset_args = { -            "sample_to_balance": sample_to_balance, -            "subsample_fraction": subsample_fraction, -            "transform": transform, -            "target_transform": target_transform, -            "seed": seed, -        } -        self.batch_size = batch_size -        self.shuffle = shuffle -        self.num_workers = num_workers -        self.cuda = cuda -        self._data_loaders = self._load_data_loaders() - -    def __repr__(self) -> str: -        """Returns information about the dataset.""" -        return self._data_loaders[self.splits[0]].dataset.__repr__() - -    @property -    def __name__(self) -> str: -        """Returns the name of the dataset.""" -        return "Emnist" - -    def __call__(self, split: str) -> DataLoader: -        """Returns the `split` DataLoader. - -        Args: -            split (str): The dataset split, i.e. train or val. - -        Returns: -            DataLoader: A PyTorch DataLoader. - -        Raises: -            ValueError: If the split does not exist. - -        """ -        try: -            return self._data_loaders[split] -        except KeyError: -            raise ValueError(f"Split {split} does not exist.") - -    def _load_data_loaders(self) -> Dict[str, DataLoader]: -        """Fetches the EMNIST dataset and return a Dict of PyTorch DataLoaders.""" -        data_loaders = {} - -        for split in ["train", "val"]: -            if split in self.splits: - -                if split == "train": -                    self.dataset_args["train"] = True -                else: -                    self.dataset_args["train"] = False - -                emnist_dataset = EmnistDataset(**self.dataset_args) - -                emnist_dataset.load_emnist_dataset() - -                data_loader = DataLoader( -                    dataset=emnist_dataset, -                    batch_size=self.batch_size, -                    shuffle=self.shuffle, -                    num_workers=self.num_workers, -                    pin_memory=self.cuda, -                ) - -                data_loaders[split] = data_loader - -        return data_loaders diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py index 1c6e959..d64a991 100644 --- a/src/text_recognizer/datasets/emnist_lines_dataset.py +++ b/src/text_recognizer/datasets/emnist_lines_dataset.py @@ -12,10 +12,9 @@ from torch.utils.data import DataLoader, Dataset  from torchvision.transforms import Compose, Normalize, ToTensor  from text_recognizer.datasets import ( -    _augment_emnist_mapping, -    _load_emnist_essentials,      DATA_DIRNAME,      EmnistDataset, +    EmnistMapper,      ESSENTIALS_FILENAME,  )  from text_recognizer.datasets.sentence_generator import SentenceGenerator @@ -30,7 +29,6 @@ class EmnistLinesDataset(Dataset):      def __init__(          self,          train: bool = False, -        emnist: Optional[EmnistDataset] = None,          transform: Optional[Callable] = None,          target_transform: Optional[Callable] = None,          max_length: int = 34, @@ -39,10 +37,9 @@ class EmnistLinesDataset(Dataset):          num_samples: int = 10000,          seed: int = 4711,      ) -> None: -        """Short summary. +        """Set attributes and loads the dataset.        
  Args: -            emnist (EmnistDataset): A EmnistDataset object.              train (bool): Flag for the filename. Defaults to False. Defaults to None.              transform (Optional[Callable]): The transform of the data. Defaults to None.              target_transform (Optional[Callable]): The transform of the target. Defaults to None. @@ -54,7 +51,6 @@ class EmnistLinesDataset(Dataset):          """          self.train = train -        self.emnist = emnist          self.transform = transform          if self.transform is None: @@ -64,11 +60,10 @@ class EmnistLinesDataset(Dataset):          if self.target_transform is None:              self.target_transform = torch.tensor -        # Load emnist dataset infromation. -        essentials = _load_emnist_essentials() -        self.mapping = _augment_emnist_mapping(dict(essentials["mapping"])) -        self.num_classes = len(self.mapping) -        self.input_shape = essentials["input_shape"] +        # Extract dataset information. +        self._mapper = EmnistMapper() +        self.input_shape = self._mapper.input_shape +        self.num_classes = self._mapper.num_classes          self.max_length = max_length          self.min_overlap = min_overlap @@ -81,10 +76,13 @@ class EmnistLinesDataset(Dataset):          self.output_shape = (self.max_length, self.num_classes)          self.seed = seed -        # Placeholders for the generated dataset. +        # Placeholders for the dataset.          self.data = None          self.target = None +        # Load dataset. +        self._load_or_generate_data() +      def __len__(self) -> int:          """Returns the length of the dataset."""          return len(self.data) @@ -104,7 +102,6 @@ class EmnistLinesDataset(Dataset):          if torch.is_tensor(index):              index = index.tolist() -        # data = np.array([self.data[index]])          data = self.data[index]          targets = self.targets[index] @@ -116,6 +113,11 @@ class EmnistLinesDataset(Dataset):          return data, targets +    @property +    def __name__(self) -> str: +        """Returns the name of the dataset.""" +        return "EmnistLinesDataset" +      def __repr__(self) -> str:          """Returns information about the dataset."""          return ( @@ -130,6 +132,11 @@ class EmnistLinesDataset(Dataset):          )      @property +    def mapper(self) -> EmnistMapper: +        """Returns the EmnistMapper.""" +        return self._mapper + +    @property      def data_filename(self) -> Path:          """Path to the h5 file."""          filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt" @@ -161,9 +168,10 @@ class EmnistLinesDataset(Dataset):          sentence_generator = SentenceGenerator(self.max_length)          # Load emnist dataset. 
-        self.emnist.load_emnist_dataset()
+        emnist = EmnistDataset(train=self.train, sample_to_balance=True)
+
         samples_by_character = get_samples_by_character(
-            self.emnist.data.numpy(), self.emnist.targets.numpy(), self.emnist.mapping,
+            emnist.data.numpy(), emnist.targets.numpy(), self.mapper.mapping,
         )
 
         DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
@@ -332,94 +340,3 @@ def create_datasets(
             num_samples=num,
         )
         emnist_lines._load_or_generate_data()
-
-
-class EmnistLinesDataLoaders:
-    """Wrapper for a PyTorch Data loaders for the EMNIST lines dataset."""
-
-    def __init__(
-        self,
-        splits: List[str],
-        max_length: int = 34,
-        min_overlap: float = 0,
-        max_overlap: float = 0.33,
-        num_samples: int = 10000,
-        transform: Optional[Callable] = None,
-        target_transform: Optional[Callable] = None,
-        batch_size: int = 128,
-        shuffle: bool = False,
-        num_workers: int = 0,
-        cuda: bool = True,
-        seed: int = 4711,
-    ) -> None:
-        """Sets the data loader arguments."""
-        self.splits = splits
-        self.dataset_args = {
-            "max_length": max_length,
-            "min_overlap": min_overlap,
-            "max_overlap": max_overlap,
-            "num_samples": num_samples,
-            "transform": transform,
-            "target_transform": target_transform,
-            "seed": seed,
-        }
-        self.batch_size = batch_size
-        self.shuffle = shuffle
-        self.num_workers = num_workers
-        self.cuda = cuda
-        self._data_loaders = self._load_data_loaders()
-
-    def __repr__(self) -> str:
-        """Returns information about the dataset."""
-        return self._data_loaders[self.splits[0]].dataset.__repr__()
-
-    @property
-    def __name__(self) -> str:
-        """Returns the name of the dataset."""
-        return "EmnistLines"
-
-    def __call__(self, split: str) -> DataLoader:
-        """Returns the `split` DataLoader.
-
-        Args:
-            split (str): The dataset split, i.e. train or val.
-
-        Returns:
-            DataLoader: A PyTorch DataLoader.
-
-        Raises:
-            ValueError: If the split does not exist.
-
-        """
-        try:
-            return self._data_loaders[split]
-        except KeyError:
-            raise ValueError(f"Split {split} does not exist.")
-
-    def _load_data_loaders(self) -> Dict[str, DataLoader]:
-        """Fetches the EMNIST Lines dataset and return a Dict of PyTorch DataLoaders."""
-        data_loaders = {}
-
-        for split in ["train", "val"]:
-            if split in self.splits:
-
-                if split == "train":
-                    self.dataset_args["train"] = True
-                else:
-                    self.dataset_args["train"] = False
-
-                emnist_lines_dataset = EmnistLinesDataset(**self.dataset_args)
-
-                emnist_lines_dataset._load_or_generate_data()
-
-                data_loader = DataLoader(
-                    dataset=emnist_lines_dataset,
-                    batch_size=self.batch_size,
-                    shuffle=self.shuffle,
-                    num_workers=self.num_workers,
-                    pin_memory=self.cuda,
-                )
-
-                data_loaders[split] = data_loader
-
-        return data_loaders
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
index 6668eef..321bc67 100644
--- a/src/text_recognizer/datasets/util.py
+++ b/src/text_recognizer/datasets/util.py
@@ -1,6 +1,10 @@
 """Util functions for datasets."""
+import importlib
+from typing import Callable, Dict, List, Type
+
 import numpy as np
 from PIL import Image
+from torch.utils.data import DataLoader, Dataset
 
 
 class Transpose:
@@ -9,3 +13,59 @@ class Transpose:
     def __call__(self, image: Image) -> np.ndarray:
         """Swaps axis."""
         return np.array(image).swapaxes(0, 1)
+
+
+def fetch_data_loaders(
+    splits: List[str],
+    dataset: str,
+    dataset_args: Dict,
+    batch_size: int = 128,
+    shuffle: bool = False,
+    num_workers: int = 0,
+    cuda: bool = True,
+) -> Dict[str, DataLoader]:
+    """Fetches DataLoaders for given split(s) as a dictionary.
+
+    Loads the given dataset class with the given dataset arguments for each of the specified splits, wraps
+    each split in a PyTorch DataLoader, and collects the loaders in a dictionary keyed by split.
+
+    Args:
+        splits (List[str]): One or both of the dataset splits "train" and "val".
+        dataset (str): The name of the dataset.
+        dataset_args (Dict): The dataset arguments.
+        batch_size (int): How many samples per batch to load. Defaults to 128.
+        shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False.
+        num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be
+            loaded in the main process. Defaults to 0.
+        cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning
+            them. Defaults to True.
+
+    Returns:
+        Dict[str, DataLoader]: Dictionary with split as key and PyTorch DataLoader as value.
+
+    """
+
+    def check_dataset_args(args: Dict, split: str) -> Dict:
+        args["train"] = True if split == "train" else False
+        return args
+
+    # Import dataset module.
+    datasets_module = importlib.import_module("text_recognizer.datasets")
+    dataset_ = getattr(datasets_module, dataset)
+
+    data_loaders = {}
+
+    for split in ["train", "val"]:
+        if split in splits:
+
+            data_loader = DataLoader(
+                dataset=dataset_(**check_dataset_args(dataset_args, split)),
+                batch_size=batch_size,
+                shuffle=shuffle,
+                num_workers=num_workers,
+                pin_memory=cuda,
+            )
+
+            data_loaders[split] = data_loader
+
+    return data_loaders
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index 84a86ca..6d40b49 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -12,6 +12,7 @@ import torch
 from torch import nn
 from torchsummary import summary
 
+from text_recognizer.datasets import EmnistMapper, fetch_data_loaders
 
 WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights"
 
@@ -23,7 +24,6 @@ class Model(ABC):
         self,
         network_fn: Type[nn.Module],
         network_args: Optional[Dict] = None,
-        data_loader: Optional[Callable] = None,
         data_loader_args: Optional[Dict] = None,
         metrics: Optional[Dict] = None,
         criterion: Optional[Callable] = None,
@@ -39,7 +39,6 @@
         Args:
             network_fn (Type[nn.Module]): The PyTorch network.
             network_args (Optional[Dict]): Arguments for the network. Defaults to None.
-            data_loader (Optional[Callable]): A function that fetches train and val DataLoader.
             data_loader_args (Optional[Dict]):  Arguments for the DataLoader.
             metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None.
             criterion (Optional[Callable]): The criterion to evaluate the performance of the network.
@@ -54,15 +53,11 @@
 
         """
-        # Fetch data loaders.
-        if data_loader_args is not None:
-            self._data_loaders = data_loader(**data_loader_args)
-            dataset_name = self._data_loaders.__name__
-            self._mapping = self._data_loaders.mapping
-        else:
-            self._mapping = None
-            dataset_name = "*"
-            self._data_loaders = None
+        # Fetch data loaders and dataset info.
+        dataset_name, self._data_loaders, self._mapper = self._load_data_loader(
+            data_loader_args
+        )
+        self._input_shape = self._mapper.input_shape
 
         self._name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}"
 
@@ -76,40 +71,15 @@
             self._device = device
 
         # Load network.
-        self._network = None
-        self._network_args = network_args
-        # If no network arguemnts are given, load pretrained weights if they exist.
-        if self._network_args is None:
-            self.load_weights(network_fn)
-        else:
-            self._network = network_fn(**self._network_args)
+        self._network, self._network_args = self._load_network(network_fn, network_args)
 
         # To device.
         self._network.to(self._device)
 
-        # Set criterion.
-        self._criterion = None
-        if criterion is not None:
-            self._criterion = criterion(**criterion_args)
-
-        # Set optimizer.
-        self._optimizer = None
-        if optimizer is not None:
-            self._optimizer = optimizer(self._network.parameters(), **optimizer_args)
-
-        # Set learning rate scheduler.
-        self._lr_scheduler = None
-        if lr_scheduler is not None:
-            # OneCycleLR needs the number of steps in an epoch as an input argument.
-            if "OneCycleLR" in str(lr_scheduler):
-                lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders("train"))
-            self._lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args)
-
-        # Extract the input shape for the torchsummary.
-        if isinstance(self._network_args["input_size"], int):
-            self._input_shape = (1,) + tuple([self._network_args["input_size"]])
-        else:
-            self._input_shape = (1,) + tuple(self._network_args["input_size"])
+        # Set training objects.
+        self._criterion = self._load_criterion(criterion, criterion_args)
+        self._optimizer = self._load_optimizer(optimizer, optimizer_args)
+        self._lr_scheduler = self._load_lr_scheduler(lr_scheduler, lr_scheduler_args)
 
         # Experiment directory.
         self.model_dir = None
@@ -117,6 +87,64 @@ class Model(ABC):
         # Flag for stopping training.
         self.stop_training = False
 
+    def _load_data_loader(
+        self, data_loader_args: Optional[Dict]
+    ) -> Tuple[str, Dict, EmnistMapper]:
+        """Loads data loader, dataset name, and dataset mapper."""
+        if data_loader_args is not None:
+            data_loaders = fetch_data_loaders(**data_loader_args)
+            dataset = list(data_loaders.values())[0].dataset
+            dataset_name = dataset.__name__
+            mapper = dataset.mapper
+        else:
+            mapper = EmnistMapper()
+            dataset_name = "*"
+            data_loaders = None
+        return dataset_name, data_loaders, mapper
+
+    def _load_network(
+        self, network_fn: Type[nn.Module], network_args: Optional[Dict]
+    ) -> Tuple[Type[nn.Module], Dict]:
+        """Loads the network."""
+        # If no network arguments are given, load pretrained weights if they exist.
+        if network_args is None:
+            network, network_args = self.load_weights(network_fn)
+        else:
+            network = network_fn(**network_args)
+        return network, network_args
+
+    def _load_criterion(
+        self, criterion: Optional[Callable], criterion_args: Optional[Dict]
+    ) -> Optional[Callable]:
+        """Loads the criterion."""
+        if criterion is not None:
+            _criterion = criterion(**criterion_args)
+        else:
+            _criterion = None
+        return _criterion
+
+    def _load_optimizer(
+        self, optimizer: Optional[Callable], optimizer_args: Optional[Dict]
+    ) -> Optional[Callable]:
+        """Loads the optimizer."""
+        if optimizer is not None:
+            _optimizer = optimizer(self._network.parameters(), **optimizer_args)
+        else:
+            _optimizer = None
+        return _optimizer
+
+    def _load_lr_scheduler(
+        self, lr_scheduler: Optional[Callable], lr_scheduler_args: Optional[Dict]
+    ) -> Optional[Callable]:
+        """Loads learning rate scheduler."""
+        if self._optimizer and lr_scheduler is not None:
+            # OneCycleLR needs the number of steps in an epoch as an input argument.
+            if "OneCycleLR" in str(lr_scheduler):
+                lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders["train"])
+            _lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args)
+        else:
+            _lr_scheduler = None
+        return _lr_scheduler
+
     @property
     def __name__(self) -> str:
         """Returns the name of the model."""
@@ -128,9 +156,14 @@
         return self._input_shape
 
     @property
+    def mapper(self) -> EmnistMapper:
+        """Returns the mapper that maps between ints and chars."""
+        return self._mapper
+
+    @property
     def mapping(self) -> Dict:
-        """Returns the class mapping."""
-        return self._mapping
+        """Returns the mapping between network output and Emnist character."""
+        return self._mapper.mapping
 
     def eval(self) -> None:
         """Sets the network to evaluation mode."""
@@ -184,7 +217,11 @@
     def summary(self) -> None:
         """Prints a summary of the network architecture."""
         device = re.sub("[^A-Za-z]+", "", self.device)
-        summary(self._network, self._input_shape, device=device)
+        if self._input_shape is not None:
+            input_shape = (1,) + tuple(self._input_shape)
+            summary(self._network, input_shape, device=device)
+        else:
+            logger.warning("Could not print summary as input shape is not set.")
 
     def _get_state_dict(self) -> Dict:
         """Get the state dict of the model."""
@@ -218,8 +255,9 @@
         if self._optimizer is not None:
             self._optimizer.load_state_dict(checkpoint["optimizer_state"])
 
-        if self._lr_scheduler is not None:
-            self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"])
+        # Does not work when loading from a previous checkpoint and trying to train beyond the last max epochs.
+        # if self._lr_scheduler is not None:
+        #     self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"])
 
         epoch = checkpoint["epoch"]
 
@@ -257,7 +295,7 @@
             )
             shutil.copyfile(filepath, str(self.model_dir / "best.pt"))
 
-    def load_weights(self, network_fn: Type[nn.Module]) -> None:
+    def load_weights(self, network_fn: Type[nn.Module]) -> Tuple[Type[nn.Module], Dict]:
         """Load the network weights."""
         logger.debug("Loading network with pretrained weights.")
         filename = glob(self.weights_filename)[0]
@@ -267,12 +305,13 @@
             )
         # Loading state directory.
         state_dict = torch.load(filename, map_location=torch.device(self._device))
-        self._network_args = state_dict["network_args"]
+        network_args = state_dict["network_args"]
         weights = state_dict["model_state"]
 
         # Initializes the network with trained weights.
-        self._network = network_fn(**self._network_args)
-        self._network.load_state_dict(weights)
+        network = network_fn(**network_args)
+        network.load_state_dict(weights)
+        return network, network_args
 
     def save_weights(self, path: Path) -> None:
         """Save the network weights."""
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
index f1dabb7..0a0ab2d 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/src/text_recognizer/models/character_model.py
@@ -6,10 +6,6 @@ import torch
 from torch import nn
 from torchvision.transforms import ToTensor
 
-from text_recognizer.datasets.emnist_dataset import (
-    _augment_emnist_mapping,
-    _load_emnist_essentials,
-)
 from text_recognizer.models.base import Model
 
 
@@ -20,7 +16,6 @@ class CharacterModel(Model):
         self,
         network_fn: Type[nn.Module],
         network_args: Optional[Dict] = None,
-        data_loader: Optional[Callable] = None,
         data_loader_args: Optional[Dict] = None,
         metrics: Optional[Dict] = None,
         criterion: Optional[Callable] = None,
@@ -36,7 +31,6 @@ class CharacterModel(Model):
         super().__init__(
             network_fn,
             network_args,
-            data_loader,
             data_loader_args,
             metrics,
             criterion,
@@ -47,16 +41,9 @@ class CharacterModel(Model):
             lr_scheduler_args,
             device,
         )
-        if self.mapping is None:
-            self.load_mapping()
         self.tensor_transform = ToTensor()
         self.softmax = nn.Softmax(dim=0)
 
-    def load_mapping(self) -> None:
-        """Mapping between integers and classes."""
-        essentials = _load_emnist_essentials()
-        self._mapping = _augment_emnist_mapping(dict(essentials["mapping"]))
-
     def predict_on_image(
         self, image: Union[np.ndarray, torch.Tensor]
     ) -> Tuple[str, float]:
@@ -86,6 +73,6 @@ class CharacterModel(Model):
         index = int(torch.argmax(prediction, dim=0))
         confidence_of_prediction = prediction[index]
 
-        predicted_character = self._mapping[index]
+        predicted_character = self.mapper(index)
 
         return predicted_character, confidence_of_prediction
diff --git a/src/text_recognizer/networks/ctc.py b/src/text_recognizer/networks/ctc.py
new file mode 100644
index 0000000..00ad47e
--- /dev/null
+++ b/src/text_recognizer/networks/ctc.py
@@ -0,0 +1,10 @@
+"""Decodes the CTC output."""
+#
+# from typing import Tuple
+# import torch
+#
+#
+# def greedy_decoder(
+#     output, labels, label_length, blank_label, collapse_repeated=True
+# ) -> Tuple[torch.Tensor, torch.Tensor]:
+#     pass
diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py
index 2839a0c..cbc58fc 100644
--- a/src/text_recognizer/networks/lenet.py
+++ b/src/text_recognizer/networks/lenet.py
@@ -1,24 +1,16 @@
 """Defines the LeNet network."""
 from typing import Callable, Dict, Optional, Tuple
 
+from einops.layers.torch import Rearrange
 import torch
 from torch import nn
 
 
-class Flatten(nn.Module):
-    """Flattens a tensor."""
-
-    def forward(self, x: int) -> torch.Tensor:
-        """Flattens a tensor for input to a nn.Linear layer."""
-        return torch.flatten(x, start_dim=1)
-
-
 class LeNet(nn.Module):
     """LeNet network."""
 
     def __init__(
         self,
-        input_size: Tuple[int, ...] = (1, 28, 28),
         channels: Tuple[int, ...] = (1, 32, 64),
         kernel_sizes: Tuple[int, ...] = (3, 3, 2),
         hidden_size: Tuple[int, ...] = (9216, 128),
@@ -30,7 +22,6 @@ class LeNet(nn.Module):
         """The LeNet network.
 
         Args:
-            input_size (Tuple[int, ...]): The input shape of the network. Defaults to (1, 28, 28).
             channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64).
             kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (3, 3, 2).
             hidden_size (Tuple[int, ...]): Size of the flattened output from the convolutional layers.
@@ -44,10 +35,9 @@ class LeNet(nn.Module):
         """
         super().__init__()
 
-        self._input_size = input_size
-
         if activation_fn is not None:
-            activation_fn = getattr(nn, activation_fn)(activation_fn_args)
+            activation_fn_args = activation_fn_args or {}
+            activation_fn = getattr(nn, activation_fn)(**activation_fn_args)
         else:
             activation_fn = nn.ReLU(inplace=True)
 
@@ -66,7 +56,7 @@ class LeNet(nn.Module):
             activation_fn,
             nn.MaxPool2d(kernel_sizes[2]),
             nn.Dropout(p=dropout_rate),
-            Flatten(),
+            Rearrange("b c h w -> b (c h w)"),
             nn.Linear(in_features=hidden_size[0], out_features=hidden_size[1]),
             activation_fn,
             nn.Dropout(p=dropout_rate),
@@ -77,6 +67,7 @@ class LeNet(nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """The feedforward."""
+        # If the batch dimension is missing, it needs to be added.
         if len(x.shape) == 3:
             x = x.unsqueeze(0)
         return self.layers(x)
diff --git a/src/text_recognizer/networks/line_lstm_ctc.py b/src/text_recognizer/networks/line_lstm_ctc.py
new file mode 100644
index 0000000..d704139
--- /dev/null
+++ b/src/text_recognizer/networks/line_lstm_ctc.py
@@ -0,0 +1,4 @@
+"""LSTM with CTC for handwritten text recognition within a line."""
+
+import torch
+from torch import nn
diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/misc.py
new file mode 100644
index 0000000..9440f9d
--- /dev/null
+++ b/src/text_recognizer/networks/misc.py
@@ -0,0 +1,28 @@
+"""Miscellaneous neural network functionality."""
+from typing import Tuple
+
+from einops import rearrange
+import torch
+from torch.nn import Unfold
+
+
+def sliding_window(
+    images: torch.Tensor, patch_size: Tuple[int, int], stride: Tuple[int, int]
+) -> torch.Tensor:
+    """Creates patches of an image.
+
+    Args:
+        images (torch.Tensor): A Torch tensor of a 4D image(s), i.e. (batch, channel, height, width).
+        patch_size (Tuple[int, int]): The size of the patches to generate, e.g. 28x28 for EMNIST.
+        stride (Tuple[int, int]): The stride of the sliding window.
+
+    Returns:
+        torch.Tensor: A tensor with the shape (batch, patches, height, width).
+
+    """
+    unfold = Unfold(kernel_size=patch_size, stride=stride)
+    patches = unfold(images)
+    patches = rearrange(
+        patches, "b (h w) c -> b c h w", h=patch_size[0], w=patch_size[1]
+    )
+    return patches
diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py
index d704d99..ac2c825 100644
--- a/src/text_recognizer/networks/mlp.py
+++ b/src/text_recognizer/networks/mlp.py
@@ -1,6 +1,7 @@
 """Defines the MLP network."""
 from typing import Callable, Dict, List, Optional, Union
 
+from einops.layers.torch import Rearrange
 import torch
 from torch import nn
 
@@ -34,7 +35,8 @@ class MLP(nn.Module):
         super().__init__()
 
         if activation_fn is not None:
-            activation_fn = getattr(nn, activation_fn)(activation_fn_args)
+            activation_fn_args = activation_fn_args or {}
+            activation_fn = getattr(nn, activation_fn)(**activation_fn_args)
         else:
             activation_fn = nn.ReLU(inplace=True)
 
@@ -42,6 +44,7 @@ class MLP(nn.Module):
             hidden_size = [hidden_size] * num_layers
 
         self.layers = [
+            Rearrange("b c h w -> b (c h w)"),
             nn.Linear(in_features=input_size, out_features=hidden_size[0]),
             activation_fn,
         ]
@@ -63,7 +66,9 @@ class MLP(nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """The feedforward."""
-        x = torch.flatten(x, start_dim=1)
+        # If the batch dimension is missing, it needs to be added.
+        if len(x.shape) == 3:
+            x = x.unsqueeze(0)
         return self.layers(x)
 
     @property
diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py
new file mode 100644
index 0000000..23394b0
--- /dev/null
+++ b/src/text_recognizer/networks/residual_network.py
@@ -0,0 +1 @@
+"""Residual CNN."""
diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt
Binary files differ
new file mode 100644
index 0000000..81ef9be
--- /dev/null
+++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt
diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt
Binary files differ
new file mode 100644
index 0000000..49bd166
--- /dev/null
+++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt
diff --git a/src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt b/src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt
Binary files differ
index 46b1cb1..ed73c09 100644
--- a/src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt
+++ b/src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt
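
A few usage sketches for the APIs this commit introduces. The new EmnistMapper replaces the module-level _load_emnist_essentials/_augment_emnist_mapping helpers with a single callable object that maps in both directions. A minimal sketch, assuming emnist_essentials.json has already been generated by download_emnist (the index 36 is just an illustrative class index):

    from text_recognizer.datasets import EmnistMapper

    mapper = EmnistMapper()

    # Index -> character: ints (or np.uint8 network outputs) are looked up
    # in the augmented EMNIST mapping.
    character = mapper(36)

    # Character -> index: strings go through the inverse mapping,
    # so the round trip recovers the original index.
    assert mapper(character) == 36

    # Tokens in neither mapping raise KeyError rather than failing silently.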
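fetch_data_loaders builds the per-split DataLoaders that EmnistDataLoaders and EmnistLinesDataLoaders used to wrap, resolving the dataset class by name via importlib. A sketch of a call as the refactored Model base class would issue it; the argument values here are illustrative, not taken from the commit:

    from text_recognizer.datasets import fetch_data_loaders

    data_loaders = fetch_data_loaders(
        splits=["train", "val"],
        dataset="EmnistDataset",  # Resolved with getattr on text_recognizer.datasets.
        dataset_args={"sample_to_balance": True},
        batch_size=128,
        shuffle=True,
        cuda=False,
    )

    # One loader per requested split; the "train" flag is injected per split.
    images, targets = next(iter(data_loaders["train"]))

Note that check_dataset_args mutates dataset_args in place, so both splits share the same argument dict with only the train flag flipped between constructions.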
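ctc.py only stubs out greedy_decoder. One standard way to fill in that signature — argmax over classes per time step, then collapse repeats and drop blanks — is sketched below. This assumes the network output has shape (batch, time, classes) and returns plain lists rather than the tensors the stub's annotation suggests; it is not necessarily how the stub will eventually be implemented:

    from typing import List, Tuple

    import torch


    def greedy_decoder(
        output: torch.Tensor,
        labels: torch.Tensor,
        label_length: torch.Tensor,
        blank_label: int,
        collapse_repeated: bool = True,
    ) -> Tuple[List[List[int]], List[List[int]]]:
        """Greedy CTC decoding: argmax per time step, collapse repeats, drop blanks."""
        arg_maxes = torch.argmax(output, dim=2)  # (batch, time)
        decodes, targets = [], []
        for i, args in enumerate(arg_maxes):
            # Trim padding from the ground truth using the stated label length.
            targets.append(labels[i][: label_length[i]].tolist())
            decode = []
            for j, index in enumerate(args):
                if index != blank_label:
                    if collapse_repeated and j != 0 and index == args[j - 1]:
                        continue
                    decode.append(index.item())
            decodes.append(decode)
        return decodes, targets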
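Both LeNet and MLP now flatten with einops' Rearrange layer instead of the custom Flatten module and torch.flatten. The two are interchangeable for this pattern, since einops groups (c h w) in the same row-major order that torch.flatten uses; a quick check:

    import torch
    from einops.layers.torch import Rearrange

    flatten = Rearrange("b c h w -> b (c h w)")
    x = torch.randn(8, 64, 12, 12)

    # Same shape and same element order as torch.flatten(x, start_dim=1).
    assert flatten(x).shape == (8, 64 * 12 * 12)
    assert torch.equal(flatten(x), torch.flatten(x, start_dim=1))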
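The new sliding_window in networks/misc.py unfolds an image into fixed-size windows, e.g. to feed one EMNIST-sized patch at a time to a line model. A shape check with non-overlapping 28x28 windows over a synthetic line image; note that the "(h w) c" grouping in the rearrange assumes single-channel input:

    import torch

    from text_recognizer.networks.misc import sliding_window

    # One 28-pixel-high line image, wide enough for 34 character windows.
    images = torch.randn(1, 1, 28, 28 * 34)
    patches = sliding_window(images, patch_size=(28, 28), stride=(28, 28))

    # Unfold yields (batch, 1 * 28 * 28, num_windows); the rearrange then
    # reshapes to (batch, patches, height, width).
    assert patches.shape == (1, 34, 28, 28)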