Diffstat (limited to 'src/text_recognizer')
-rw-r--r--  src/text_recognizer/datasets/__init__.py                                    13
-rw-r--r--  src/text_recognizer/datasets/emnist_dataset.py                             275
-rw-r--r--  src/text_recognizer/datasets/emnist_lines_dataset.py                       129
-rw-r--r--  src/text_recognizer/datasets/util.py                                        60
-rw-r--r--  src/text_recognizer/models/base.py                                         139
-rw-r--r--  src/text_recognizer/models/character_model.py                               15
-rw-r--r--  src/text_recognizer/networks/ctc.py                                         10
-rw-r--r--  src/text_recognizer/networks/lenet.py                                       19
-rw-r--r--  src/text_recognizer/networks/line_lstm_ctc.py                                4
-rw-r--r--  src/text_recognizer/networks/misc.py                                        28
-rw-r--r--  src/text_recognizer/networks/mlp.py                                          9
-rw-r--r--  src/text_recognizer/networks/residual_network.py                             1
-rw-r--r--  src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt  bin 0 -> 14485310 bytes
-rw-r--r--  src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt    bin 0 -> 1704174 bytes
-rw-r--r--  src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt         bin 14485305 -> 14485342 bytes
15 files changed, 346 insertions, 356 deletions
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py
index 1b4cc59..05f74f6 100644
--- a/src/text_recognizer/datasets/__init__.py
+++ b/src/text_recognizer/datasets/__init__.py
@@ -1,29 +1,24 @@
"""Dataset modules."""
from .emnist_dataset import (
- _augment_emnist_mapping,
- _load_emnist_essentials,
DATA_DIRNAME,
- EmnistDataLoaders,
EmnistDataset,
+ EmnistMapper,
ESSENTIALS_FILENAME,
)
from .emnist_lines_dataset import (
construct_image_from_string,
- EmnistLinesDataLoaders,
EmnistLinesDataset,
get_samples_by_character,
)
-from .util import Transpose
+from .util import fetch_data_loaders, Transpose
__all__ = [
- "_augment_emnist_mapping",
- "_load_emnist_essentials",
"construct_image_from_string",
"DATA_DIRNAME",
"EmnistDataset",
- "EmnistDataLoaders",
- "EmnistLinesDataLoaders",
+ "EmnistMapper",
"EmnistLinesDataset",
+ "fetch_data_loaders",
"get_samples_by_character",
"Transpose",
]
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index f3d65ee..96f84e5 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -39,45 +39,101 @@ def download_emnist() -> None:
save_emnist_essentials(dataset)
-def _load_emnist_essentials() -> Dict:
- """Load the EMNIST mapping."""
- with open(str(ESSENTIALS_FILENAME)) as f:
- essentials = json.load(f)
- return essentials
-
-
-def _augment_emnist_mapping(mapping: Dict) -> Dict:
- """Augment the mapping with extra symbols."""
- # Extra symbols in IAM dataset
- extra_symbols = [
- " ",
- "!",
- '"',
- "#",
- "&",
- "'",
- "(",
- ")",
- "*",
- "+",
- ",",
- "-",
- ".",
- "/",
- ":",
- ";",
- "?",
- ]
-
- # padding symbol
- extra_symbols.append("_")
-
- max_key = max(mapping.keys())
- extra_mapping = {}
- for i, symbol in enumerate(extra_symbols):
- extra_mapping[max_key + 1 + i] = symbol
-
- return {**mapping, **extra_mapping}
+class EmnistMapper:
+ """Mapper between network output to Emnist character."""
+
+ def __init__(self) -> None:
+ """Loads the emnist essentials file with the mapping and input shape."""
+ self.essentials = self._load_emnist_essentials()
+ # Load dataset information.
+ self._mapping = self._augment_emnist_mapping(dict(self.essentials["mapping"]))
+ self._inverse_mapping = {v: k for k, v in self.mapping.items()}
+ self._num_classes = len(self.mapping)
+ self._input_shape = self.essentials["input_shape"]
+
+ def __call__(self, token: Union[str, int, np.uint8]) -> Union[str, int]:
+ """Maps the token to emnist character or character index.
+
+ If the token is an integer (index), the method will return the Emnist character corresponding to that index.
+ If the token is a str (Emnist character), the method will return the corresponding index for that character.
+
+ Args:
+ token (Union[str, int, np.uint8]): Either a string or index (integer).
+
+ Returns:
+ Union[str, int]: The mapping result.
+
+ Raises:
+ KeyError: If the index or string does not exist in the mapping.
+
+ """
+ if (isinstance(token, np.uint8) or isinstance(token, int)) and int(
+ token
+ ) in self.mapping:
+ return self.mapping[int(token)]
+ elif isinstance(token, str) and token in self._inverse_mapping:
+ return self._inverse_mapping[token]
+ else:
+ raise KeyError(f"Token {token} does not exist in the mappings.")
+
+ @property
+ def mapping(self) -> Dict:
+ """Returns the mapping between index and character."""
+ return self._mapping
+
+ @property
+ def inverse_mapping(self) -> Dict:
+ """Returns the mapping between character and index."""
+ return self._inverse_mapping
+
+ @property
+ def num_classes(self) -> int:
+ """Returns the number of classes in the dataset."""
+ return self._num_classes
+
+ @property
+ def input_shape(self) -> List[int]:
+ """Returns the input shape of the Emnist characters."""
+ return self._input_shape
+
+ def _load_emnist_essentials(self) -> Dict:
+ """Load the EMNIST mapping."""
+ with open(str(ESSENTIALS_FILENAME)) as f:
+ essentials = json.load(f)
+ return essentials
+
+ def _augment_emnist_mapping(self, mapping: Dict) -> Dict:
+ """Augment the mapping with extra symbols."""
+ # Extra symbols in IAM dataset
+ extra_symbols = [
+ " ",
+ "!",
+ '"',
+ "#",
+ "&",
+ "'",
+ "(",
+ ")",
+ "*",
+ "+",
+ ",",
+ "-",
+ ".",
+ "/",
+ ":",
+ ";",
+ "?",
+ ]
+
+ # padding symbol
+ extra_symbols.append("_")
+
+ max_key = max(mapping.keys())
+ extra_mapping = {}
+ for i, symbol in enumerate(extra_symbols):
+ extra_mapping[max_key + 1 + i] = symbol
+
+ return {**mapping, **extra_mapping}
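For reference, a round-trip through the new EmnistMapper looks like the following — a minimal sketch, assuming the essentials file has already been generated by download_emnist; the exact index/character pairs depend on the generated mapping:

    mapper = EmnistMapper()
    character = mapper(36)   # index -> character, e.g. "a" in a byclass-style mapping
    index = mapper("a")      # character -> index, the inverse lookup
    mapper(-1)               # raises KeyError: token not in the mappings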
class EmnistDataset(Dataset):
@@ -110,10 +166,12 @@ class EmnistDataset(Dataset):
self.train = train
self.sample_to_balance = sample_to_balance
+
if subsample_fraction is not None:
if not 0.0 < subsample_fraction < 1.0:
raise ValueError("The subsample fraction must be in (0, 1).")
self.subsample_fraction = subsample_fraction
+
self.transform = transform
if self.transform is None:
self.transform = Compose([Transpose(), ToTensor()])
@@ -121,17 +179,22 @@ class EmnistDataset(Dataset):
self.target_transform = target_transform
self.seed = seed
- # Load dataset infromation.
- essentials = _load_emnist_essentials()
- self.mapping = _augment_emnist_mapping(dict(essentials["mapping"]))
- self.inverse_mapping = {v: k for k, v in self.mapping.items()}
- self.num_classes = len(self.mapping)
- self.input_shape = essentials["input_shape"]
+ self._mapper = EmnistMapper()
+ self.input_shape = self._mapper.input_shape
+ self.num_classes = self._mapper.num_classes
# Placeholders
self.data = None
self.targets = None
+ # Load dataset.
+ self.load_emnist_dataset()
+
+ @property
+ def mapper(self) -> EmnistMapper:
+ """Returns the EmnistMapper."""
+ return self._mapper
+
def __len__(self) -> int:
"""Returns the length of the dataset."""
return len(self.data)
@@ -162,13 +225,18 @@ class EmnistDataset(Dataset):
return data, targets
+ @property
+ def __name__(self) -> str:
+ """Returns the name of the dataset."""
+ return "EmnistDataset"
+
def __repr__(self) -> str:
"""Returns information about the dataset."""
return (
"EMNIST Dataset\n"
f"Num classes: {self.num_classes}\n"
- f"Mapping: {self.mapping}\n"
f"Input shape: {self.input_shape}\n"
+ f"Mapping: {self.mapper.mapping}\n"
)
def _sample_to_balance(self) -> None:
@@ -217,118 +285,3 @@ class EmnistDataset(Dataset):
if self.subsample_fraction is not None:
self._subsample()
-
-
-class EmnistDataLoaders:
- """Class for Emnist DataLoaders."""
-
- def __init__(
- self,
- splits: List[str],
- sample_to_balance: bool = False,
- subsample_fraction: float = None,
- transform: Optional[Callable] = None,
- target_transform: Optional[Callable] = None,
- batch_size: int = 128,
- shuffle: bool = False,
- num_workers: int = 0,
- cuda: bool = True,
- seed: int = 4711,
- ) -> None:
- """Fetches DataLoaders for given split(s).
-
- Args:
- splits (List[str]): One or both of the dataset splits "train" and "val".
- sample_to_balance (bool): If true, resamples the unbalanced if the split "byclass" is selected.
- Defaults to False.
- subsample_fraction (float): The fraction of the dataset will be loaded. If None or 0 the entire
- dataset will be loaded.
- transform (Optional[Callable]): A function/transform that takes in an PIL image and returns a
- transformed version. E.g, transforms.RandomCrop. Defaults to None.
- target_transform (Optional[Callable]): A function/transform that takes in the target and
- transforms it. Defaults to None.
- batch_size (int): How many samples per batch to load. Defaults to 128.
- shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False.
- num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be
- loaded in the main process. Defaults to 0.
- cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning
- them. Defaults to True.
- seed (int): Seed for sampling.
-
- Raises:
- ValueError: If subsample_fraction is not None and outside the range (0, 1).
-
- """
- self.splits = splits
-
- if subsample_fraction is not None:
- if not 0.0 < subsample_fraction < 1.0:
- raise ValueError("The subsample fraction must be in (0, 1).")
-
- self.dataset_args = {
- "sample_to_balance": sample_to_balance,
- "subsample_fraction": subsample_fraction,
- "transform": transform,
- "target_transform": target_transform,
- "seed": seed,
- }
- self.batch_size = batch_size
- self.shuffle = shuffle
- self.num_workers = num_workers
- self.cuda = cuda
- self._data_loaders = self._load_data_loaders()
-
- def __repr__(self) -> str:
- """Returns information about the dataset."""
- return self._data_loaders[self.splits[0]].dataset.__repr__()
-
- @property
- def __name__(self) -> str:
- """Returns the name of the dataset."""
- return "Emnist"
-
- def __call__(self, split: str) -> DataLoader:
- """Returns the `split` DataLoader.
-
- Args:
- split (str): The dataset split, i.e. train or val.
-
- Returns:
- DataLoader: A PyTorch DataLoader.
-
- Raises:
- ValueError: If the split does not exist.
-
- """
- try:
- return self._data_loaders[split]
- except KeyError:
- raise ValueError(f"Split {split} does not exist.")
-
- def _load_data_loaders(self) -> Dict[str, DataLoader]:
- """Fetches the EMNIST dataset and return a Dict of PyTorch DataLoaders."""
- data_loaders = {}
-
- for split in ["train", "val"]:
- if split in self.splits:
-
- if split == "train":
- self.dataset_args["train"] = True
- else:
- self.dataset_args["train"] = False
-
- emnist_dataset = EmnistDataset(**self.dataset_args)
-
- emnist_dataset.load_emnist_dataset()
-
- data_loader = DataLoader(
- dataset=emnist_dataset,
- batch_size=self.batch_size,
- shuffle=self.shuffle,
- num_workers=self.num_workers,
- pin_memory=self.cuda,
- )
-
- data_loaders[split] = data_loader
-
- return data_loaders
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index 1c6e959..d64a991 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -12,10 +12,9 @@ from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, Normalize, ToTensor
from text_recognizer.datasets import (
- _augment_emnist_mapping,
- _load_emnist_essentials,
DATA_DIRNAME,
EmnistDataset,
+ EmnistMapper,
ESSENTIALS_FILENAME,
)
from text_recognizer.datasets.sentence_generator import SentenceGenerator
@@ -30,7 +29,6 @@ class EmnistLinesDataset(Dataset):
def __init__(
self,
train: bool = False,
- emnist: Optional[EmnistDataset] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
max_length: int = 34,
@@ -39,10 +37,9 @@ class EmnistLinesDataset(Dataset):
num_samples: int = 10000,
seed: int = 4711,
) -> None:
- """Short summary.
+ """Set attributes and loads the dataset.
Args:
- emnist (EmnistDataset): A EmnistDataset object.
train (bool): Flag for the filename. Defaults to False.
transform (Optional[Callable]): The transform of the data. Defaults to None.
target_transform (Optional[Callable]): The transform of the target. Defaults to None.
@@ -54,7 +51,6 @@ class EmnistLinesDataset(Dataset):
"""
self.train = train
- self.emnist = emnist
self.transform = transform
if self.transform is None:
@@ -64,11 +60,10 @@ class EmnistLinesDataset(Dataset):
if self.target_transform is None:
self.target_transform = torch.tensor
- # Load emnist dataset infromation.
- essentials = _load_emnist_essentials()
- self.mapping = _augment_emnist_mapping(dict(essentials["mapping"]))
- self.num_classes = len(self.mapping)
- self.input_shape = essentials["input_shape"]
+ # Extract dataset information.
+ self._mapper = EmnistMapper()
+ self.input_shape = self._mapper.input_shape
+ self.num_classes = self._mapper.num_classes
self.max_length = max_length
self.min_overlap = min_overlap
@@ -81,10 +76,13 @@ class EmnistLinesDataset(Dataset):
self.output_shape = (self.max_length, self.num_classes)
self.seed = seed
- # Placeholders for the generated dataset.
+ # Placeholders for the dataset.
self.data = None
self.target = None
+ # Load dataset.
+ self._load_or_generate_data()
+
def __len__(self) -> int:
"""Returns the length of the dataset."""
return len(self.data)
@@ -104,7 +102,6 @@ class EmnistLinesDataset(Dataset):
if torch.is_tensor(index):
index = index.tolist()
- # data = np.array([self.data[index]])
data = self.data[index]
targets = self.targets[index]
@@ -116,6 +113,11 @@ class EmnistLinesDataset(Dataset):
return data, targets
+ @property
+ def __name__(self) -> str:
+ """Returns the name of the dataset."""
+ return "EmnistLinesDataset"
+
def __repr__(self) -> str:
"""Returns information about the dataset."""
return (
@@ -130,6 +132,11 @@ class EmnistLinesDataset(Dataset):
)
@property
+ def mapper(self) -> EmnistMapper:
+ """Returns the EmnistMapper."""
+ return self._mapper
+
+ @property
def data_filename(self) -> Path:
"""Path to the h5 file."""
filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt"
@@ -161,9 +168,10 @@ class EmnistLinesDataset(Dataset):
sentence_generator = SentenceGenerator(self.max_length)
# Load emnist dataset.
- self.emnist.load_emnist_dataset()
+ emnist = EmnistDataset(train=self.train, sample_to_balance=True)
+
samples_by_character = get_samples_by_character(
- self.emnist.data.numpy(), self.emnist.targets.numpy(), self.emnist.mapping,
+ emnist.data.numpy(), emnist.targets.numpy(), self.mapper.mapping,
)
DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
@@ -332,94 +340,3 @@ def create_datasets(
num_samples=num,
)
emnist_lines._load_or_generate_data()
-
-
-class EmnistLinesDataLoaders:
- """Wrapper for a PyTorch Data loaders for the EMNIST lines dataset."""
-
- def __init__(
- self,
- splits: List[str],
- max_length: int = 34,
- min_overlap: float = 0,
- max_overlap: float = 0.33,
- num_samples: int = 10000,
- transform: Optional[Callable] = None,
- target_transform: Optional[Callable] = None,
- batch_size: int = 128,
- shuffle: bool = False,
- num_workers: int = 0,
- cuda: bool = True,
- seed: int = 4711,
- ) -> None:
- """Sets the data loader arguments."""
- self.splits = splits
- self.dataset_args = {
- "max_length": max_length,
- "min_overlap": min_overlap,
- "max_overlap": max_overlap,
- "num_samples": num_samples,
- "transform": transform,
- "target_transform": target_transform,
- "seed": seed,
- }
- self.batch_size = batch_size
- self.shuffle = shuffle
- self.num_workers = num_workers
- self.cuda = cuda
- self._data_loaders = self._load_data_loaders()
-
- def __repr__(self) -> str:
- """Returns information about the dataset."""
- return self._data_loaders[self.splits[0]].dataset.__repr__()
-
- @property
- def __name__(self) -> str:
- """Returns the name of the dataset."""
- return "EmnistLines"
-
- def __call__(self, split: str) -> DataLoader:
- """Returns the `split` DataLoader.
-
- Args:
- split (str): The dataset split, i.e. train or val.
-
- Returns:
- DataLoader: A PyTorch DataLoader.
-
- Raises:
- ValueError: If the split does not exist.
-
- """
- try:
- return self._data_loaders[split]
- except KeyError:
- raise ValueError(f"Split {split} does not exist.")
-
- def _load_data_loaders(self) -> Dict[str, DataLoader]:
- """Fetches the EMNIST Lines dataset and return a Dict of PyTorch DataLoaders."""
- data_loaders = {}
-
- for split in ["train", "val"]:
- if split in self.splits:
-
- if split == "train":
- self.dataset_args["train"] = True
- else:
- self.dataset_args["train"] = False
-
- emnist_lines_dataset = EmnistLinesDataset(**self.dataset_args)
-
- emnist_lines_dataset._load_or_generate_data()
-
- data_loader = DataLoader(
- dataset=emnist_lines_dataset,
- batch_size=self.batch_size,
- shuffle=self.shuffle,
- num_workers=self.num_workers,
- pin_memory=self.cuda,
- )
-
- data_loaders[split] = data_loader
-
- return data_loaders
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
index 6668eef..321bc67 100644
--- a/src/text_recognizer/datasets/util.py
+++ b/src/text_recognizer/datasets/util.py
@@ -1,6 +1,10 @@
"""Util functions for datasets."""
+import importlib
+from typing import Callable, Dict, List, Type
+
import numpy as np
from PIL import Image
+from torch.utils.data import DataLoader, Dataset
class Transpose:
@@ -9,3 +13,59 @@ class Transpose:
def __call__(self, image: Image) -> np.ndarray:
"""Swaps axis."""
return np.array(image).swapaxes(0, 1)
+
+
+def fetch_data_loaders(
+ splits: List[str],
+ dataset: str,
+ dataset_args: Dict,
+ batch_size: int = 128,
+ shuffle: bool = False,
+ num_workers: int = 0,
+ cuda: bool = True,
+) -> Dict[str, DataLoader]:
+ """Fetches DataLoaders for given split(s) as a dictionary.
+
+ Loads the given dataset class, instantiates it with the dataset arguments for each requested split, and wraps
+ each dataset in a DataLoader. The loaders are returned in a dictionary with the split as key.
+
+ Args:
+ splits (List[str]): One or both of the dataset splits "train" and "val".
+ dataset (str): The name of the dataset.
+ dataset_args (Dict): The dataset arguments.
+ batch_size (int): How many samples per batch to load. Defaults to 128.
+ shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False.
+ num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be
+ loaded in the main process. Defaults to 0.
+ cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning
+ them. Defaults to True.
+
+ Returns:
+ Dict[str, DataLoader]: Dictionary with split as key and PyTorch DataLoader as value.
+
+ """
+
+ def check_dataset_args(args: Dict, split: str) -> Dict:
+ args["train"] = True if split == "train" else False
+ return args
+
+ # Import dataset module.
+ datasets_module = importlib.import_module("text_recognizer.datasets")
+ dataset_ = getattr(datasets_module, dataset)
+
+ data_loaders = {}
+
+ for split in ["train", "val"]:
+ if split in splits:
+
+ data_loader = DataLoader(
+ dataset=dataset_(**check_dataset_args(dataset_args, split)),
+ batch_size=batch_size,
+ shuffle=shuffle,
+ num_workers=num_workers,
+ pin_memory=cuda,
+ )
+
+ data_loaders[split] = data_loader
+
+ return data_loaders
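Taken together with the removal of EmnistDataLoaders and EmnistLinesDataLoaders, this helper becomes the single entry point for loaders. A hedged usage sketch (the argument values are illustrative, not taken from this commit):

    loaders = fetch_data_loaders(
        splits=["train", "val"],
        dataset="EmnistDataset",
        dataset_args={"sample_to_balance": True, "seed": 4711},
        batch_size=128,
        shuffle=True,
        num_workers=4,
        cuda=True,
    )
    train_loader = loaders["train"]  # keys mirror the requested split names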
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index 84a86ca..6d40b49 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -12,6 +12,7 @@ import torch
from torch import nn
from torchsummary import summary
+from text_recognizer.datasets import EmnistMapper, fetch_data_loaders
WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights"
@@ -23,7 +24,6 @@ class Model(ABC):
self,
network_fn: Type[nn.Module],
network_args: Optional[Dict] = None,
- data_loader: Optional[Callable] = None,
data_loader_args: Optional[Dict] = None,
metrics: Optional[Dict] = None,
criterion: Optional[Callable] = None,
@@ -39,7 +39,6 @@ class Model(ABC):
Args:
network_fn (Type[nn.Module]): The PyTorch network.
network_args (Optional[Dict]): Arguments for the network. Defaults to None.
- data_loader (Optional[Callable]): A function that fetches train and val DataLoader.
data_loader_args (Optional[Dict]): Arguments for the DataLoader.
metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None.
criterion (Optional[Callable]): The criterion to evaluate the performance of the network.
@@ -54,15 +53,11 @@ class Model(ABC):
"""
- # Fetch data loaders.
- if data_loader_args is not None:
- self._data_loaders = data_loader(**data_loader_args)
- dataset_name = self._data_loaders.__name__
- self._mapping = self._data_loaders.mapping
- else:
- self._mapping = None
- dataset_name = "*"
- self._data_loaders = None
+ # Fetch data loaders and dataset info.
+ dataset_name, self._data_loaders, self._mapper = self._load_data_loader(
+ data_loader_args
+ )
+ self._input_shape = self._mapper.input_shape
self._name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}"
@@ -76,40 +71,15 @@ class Model(ABC):
self._device = device
# Load network.
- self._network = None
- self._network_args = network_args
- # If no network arguemnts are given, load pretrained weights if they exist.
- if self._network_args is None:
- self.load_weights(network_fn)
- else:
- self._network = network_fn(**self._network_args)
+ self._network, self._network_args = self._load_network(network_fn, network_args)
# To device.
self._network.to(self._device)
- # Set criterion.
- self._criterion = None
- if criterion is not None:
- self._criterion = criterion(**criterion_args)
-
- # Set optimizer.
- self._optimizer = None
- if optimizer is not None:
- self._optimizer = optimizer(self._network.parameters(), **optimizer_args)
-
- # Set learning rate scheduler.
- self._lr_scheduler = None
- if lr_scheduler is not None:
- # OneCycleLR needs the number of steps in an epoch as an input argument.
- if "OneCycleLR" in str(lr_scheduler):
- lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders("train"))
- self._lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args)
-
- # Extract the input shape for the torchsummary.
- if isinstance(self._network_args["input_size"], int):
- self._input_shape = (1,) + tuple([self._network_args["input_size"]])
- else:
- self._input_shape = (1,) + tuple(self._network_args["input_size"])
+ # Set training objects.
+ self._criterion = self._load_criterion(criterion, criterion_args)
+ self._optimizer = self._load_optimizer(optimizer, optimizer_args)
+ self._lr_scheduler = self._load_lr_scheduler(lr_scheduler, lr_scheduler_args)
# Experiment directory.
self.model_dir = None
@@ -117,6 +87,64 @@ class Model(ABC):
# Flag for stopping training.
self.stop_training = False
+ def _load_data_loader(
+ self, data_loader_args: Optional[Dict]
+ ) -> Tuple[str, Dict, EmnistMapper]:
+ """Loads data loader, dataset name, and dataset mapper."""
+ if data_loader_args is not None:
+ data_loaders = fetch_data_loaders(**data_loader_args)
+ dataset = list(data_loaders.values())[0].dataset
+ dataset_name = dataset.__name__
+ mapper = dataset.mapper
+ else:
+ mapper = EmnistMapper()
+ dataset_name = "*"
+ data_loaders = None
+ return dataset_name, data_loaders, mapper
+
+ def _load_network(
+ self, network_fn: Type[nn.Module], network_args: Optional[Dict]
+ ) -> Tuple[Type[nn.Module], Dict]:
+ """Loads the network."""
+ # If no network arguments are given, load pretrained weights if they exist.
+ if network_args is None:
+ network, network_args = self.load_weights(network_fn)
+ else:
+ network = network_fn(**network_args)
+ return network, network_args
+
+ def _load_criterion(
+ self, criterion: Optional[Callable], criterion_args: Optional[Dict]
+ ) -> Optional[Callable]:
+ """Loads the criterion."""
+ if criterion is not None:
+ _criterion = criterion(**criterion_args)
+ else:
+ _criterion = None
+ return _criterion
+
+ def _load_optimizer(
+ self, optimizer: Optional[Callable], optimizer_args: Optional[Dict]
+ ) -> Optional[Callable]:
+ """Loads the optimizer."""
+ if optimizer is not None:
+ _optimizer = optimizer(self._network.parameters(), **optimizer_args)
+ else:
+ _optimizer = None
+ return _optimizer
+
+ def _load_lr_scheduler(
+ self, lr_scheduler: Optional[Callable], lr_scheduler_args: Optional[Dict]
+ ) -> Optional[Callable]:
+ """Loads learning rate scheduler."""
+ if self._optimizer and lr_scheduler is not None:
+ if "OneCycleLR" in str(lr_scheduler):
+ lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders["train"])
+ _lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args)
+ else:
+ _lr_scheduler = None
+ return _lr_scheduler
+
@property
def __name__(self) -> str:
"""Returns the name of the model."""
@@ -128,9 +156,14 @@ class Model(ABC):
return self._input_shape
@property
+ def mapper(self) -> Dict:
+ """Returns the mapper that maps between ints and chars."""
+ return self._mapper
+
+ @property
def mapping(self) -> Dict:
- """Returns the class mapping."""
- return self._mapping
+ """Returns the mapping between network output and Emnist character."""
+ return self._mapper.mapping
def eval(self) -> None:
"""Sets the network to evaluation mode."""
@@ -184,7 +217,11 @@ class Model(ABC):
def summary(self) -> None:
"""Prints a summary of the network architecture."""
device = re.sub("[^A-Za-z]+", "", self.device)
- summary(self._network, self._input_shape, device=device)
+ if self._input_shape is not None:
+ input_shape = (1,) + tuple(self._input_shape)
+ summary(self._network, input_shape, device=device)
+ else:
+ logger.warning("Could not print summary as input shape is not set.")
def _get_state_dict(self) -> Dict:
"""Get the state dict of the model."""
@@ -218,8 +255,9 @@ class Model(ABC):
if self._optimizer is not None:
self._optimizer.load_state_dict(checkpoint["optimizer_state"])
- if self._lr_scheduler is not None:
- self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"])
+ # Does not work when loading from a previous checkpoint and trying to train beyond the last max epochs.
+ # if self._lr_scheduler is not None:
+ # self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"])
epoch = checkpoint["epoch"]
@@ -257,7 +295,7 @@ class Model(ABC):
)
shutil.copyfile(filepath, str(self.model_dir / "best.pt"))
- def load_weights(self, network_fn: Type[nn.Module]) -> None:
+ def load_weights(self, network_fn: Type[nn.Module]) -> Tuple[Type[nn.Module], Dict]:
"""Load the network weights."""
logger.debug("Loading network with pretrained weights.")
filename = glob(self.weights_filename)[0]
@@ -267,12 +305,13 @@ class Model(ABC):
)
# Loading state directory.
state_dict = torch.load(filename, map_location=torch.device(self._device))
- self._network_args = state_dict["network_args"]
+ network_args = state_dict["network_args"]
weights = state_dict["model_state"]
# Initializes the network with trained weights.
- self._network = network_fn(**self._network_args)
- self._network.load_state_dict(weights)
+ network = network_fn(**network_args)
+ network.load_state_dict(weights)
+ return network, network_args
def save_weights(self, path: Path) -> None:
"""Save the network weights."""
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
index f1dabb7..0a0ab2d 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/src/text_recognizer/models/character_model.py
@@ -6,10 +6,6 @@ import torch
from torch import nn
from torchvision.transforms import ToTensor
-from text_recognizer.datasets.emnist_dataset import (
- _augment_emnist_mapping,
- _load_emnist_essentials,
-)
from text_recognizer.models.base import Model
@@ -20,7 +16,6 @@ class CharacterModel(Model):
self,
network_fn: Type[nn.Module],
network_args: Optional[Dict] = None,
- data_loader: Optional[Callable] = None,
data_loader_args: Optional[Dict] = None,
metrics: Optional[Dict] = None,
criterion: Optional[Callable] = None,
@@ -36,7 +31,6 @@ class CharacterModel(Model):
super().__init__(
network_fn,
network_args,
- data_loader,
data_loader_args,
metrics,
criterion,
@@ -47,16 +41,9 @@ class CharacterModel(Model):
lr_scheduler_args,
device,
)
- if self.mapping is None:
- self.load_mapping()
self.tensor_transform = ToTensor()
self.softmax = nn.Softmax(dim=0)
- def load_mapping(self) -> None:
- """Mapping between integers and classes."""
- essentials = _load_emnist_essentials()
- self._mapping = _augment_emnist_mapping(dict(essentials["mapping"]))
-
def predict_on_image(
self, image: Union[np.ndarray, torch.Tensor]
) -> Tuple[str, float]:
@@ -86,6 +73,6 @@ class CharacterModel(Model):
index = int(torch.argmax(prediction, dim=0))
confidence_of_prediction = prediction[index]
- predicted_character = self._mapping[index]
+ predicted_character = self.mapper(index)
return predicted_character, confidence_of_prediction
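Condensed, the prediction path now runs the network output through softmax, argmax, and the shared mapper — a sketch with shapes assumed from the surrounding code, where prediction is a 1D tensor over classes:

    logits = self._network(image_tensor).squeeze(0)  # assumed shape: (num_classes,)
    prediction = self.softmax(logits)
    index = int(torch.argmax(prediction, dim=0))
    confidence = prediction[index]
    character = self.mapper(index)                   # int -> Emnist character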
diff --git a/src/text_recognizer/networks/ctc.py b/src/text_recognizer/networks/ctc.py
new file mode 100644
index 0000000..00ad47e
--- /dev/null
+++ b/src/text_recognizer/networks/ctc.py
@@ -0,0 +1,10 @@
+"""Decodes the CTC output."""
+#
+# from typing import Tuple
+# import torch
+#
+#
+# def greedy_decoder(
+# output, labels, label_length, blank_label, collapse_repeated=True
+# ) -> Tuple[torch.Tensor, torch.Tensor]:
+# pass
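Since the committed file only carries a commented-out signature, here is a hedged sketch of what a best-path (greedy) CTC decoder typically looks like; the simplified signature and the blank index default are assumptions, not part of this commit:

    import torch

    def greedy_decoder(output: torch.Tensor, blank_label: int = 0, collapse_repeated: bool = True) -> list:
        """Decodes (batch, time, classes) probabilities into index sequences."""
        best_path = torch.argmax(output, dim=2)  # (batch, time)
        decoded = []
        for sequence in best_path:
            indices = []
            previous = blank_label
            for index in sequence.tolist():
                # Skip blanks and, optionally, immediate repeats.
                if index != blank_label and (not collapse_repeated or index != previous):
                    indices.append(index)
                previous = index
            decoded.append(indices)
        return decoded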
diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py
index 2839a0c..cbc58fc 100644
--- a/src/text_recognizer/networks/lenet.py
+++ b/src/text_recognizer/networks/lenet.py
@@ -1,24 +1,16 @@
"""Defines the LeNet network."""
from typing import Callable, Dict, Optional, Tuple
+from einops.layers.torch import Rearrange
import torch
from torch import nn
-class Flatten(nn.Module):
- """Flattens a tensor."""
-
- def forward(self, x: int) -> torch.Tensor:
- """Flattens a tensor for input to a nn.Linear layer."""
- return torch.flatten(x, start_dim=1)
-
-
class LeNet(nn.Module):
"""LeNet network."""
def __init__(
self,
- input_size: Tuple[int, ...] = (1, 28, 28),
channels: Tuple[int, ...] = (1, 32, 64),
kernel_sizes: Tuple[int, ...] = (3, 3, 2),
hidden_size: Tuple[int, ...] = (9216, 128),
@@ -30,7 +22,6 @@ class LeNet(nn.Module):
"""The LeNet network.
Args:
- input_size (Tuple[int, ...]): The input shape of the network. Defaults to (1, 28, 28).
channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64).
kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (3, 3, 2).
hidden_size (Tuple[int, ...]): Size of the flattened output from the convolutional layers.
@@ -44,10 +35,9 @@ class LeNet(nn.Module):
"""
super().__init__()
- self._input_size = input_size
-
if activation_fn is not None:
- activation_fn = getattr(nn, activation_fn)(activation_fn_args)
+ activation_fn_args = activation_fn_args or {}
+ activation_fn = getattr(nn, activation_fn)(**activation_fn_args)
else:
activation_fn = nn.ReLU(inplace=True)
@@ -66,7 +56,7 @@ class LeNet(nn.Module):
activation_fn,
nn.MaxPool2d(kernel_sizes[2]),
nn.Dropout(p=dropout_rate),
- Flatten(),
+ Rearrange("b c h w -> b (c h w)"),
nn.Linear(in_features=hidden_size[0], out_features=hidden_size[1]),
activation_fn,
nn.Dropout(p=dropout_rate),
@@ -77,6 +67,7 @@ class LeNet(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""The feedforward."""
+ # If the batch dimension is missing, it needs to be added.
if len(x.shape) == 3:
x = x.unsqueeze(0)
return self.layers(x)
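The Rearrange layer is a drop-in replacement for the deleted Flatten module; for any 4D input the two produce identical tensors:

    import torch
    from einops.layers.torch import Rearrange

    x = torch.randn(16, 64, 12, 12)  # (batch, channels, height, width)
    flattened = Rearrange("b c h w -> b (c h w)")(x)
    assert flattened.shape == (16, 64 * 12 * 12)
    assert torch.equal(flattened, torch.flatten(x, start_dim=1))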
diff --git a/src/text_recognizer/networks/line_lstm_ctc.py b/src/text_recognizer/networks/line_lstm_ctc.py
new file mode 100644
index 0000000..d704139
--- /dev/null
+++ b/src/text_recognizer/networks/line_lstm_ctc.py
@@ -0,0 +1,4 @@
+"""LSTM with CTC for handwritten text recognition within a line."""
+
+import torch
+from torch import nn
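The module body is still empty; purely as an assumption about where the filename points, a line recognizer of this kind usually pairs a recurrent encoder with per-step log-probabilities for nn.CTCLoss, roughly:

    class LineRecurrentNetwork(nn.Module):  # hypothetical name, not in this commit
        def __init__(self, input_size: int = 28 * 28, hidden_size: int = 128, num_classes: int = 80) -> None:
            super().__init__()
            self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)
            self.fc = nn.Linear(in_features=hidden_size, out_features=num_classes)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x: (batch, time, features), e.g. flattened sliding-window patches.
            hidden, _ = self.lstm(x)
            return torch.log_softmax(self.fc(hidden), dim=2)  # log-probs for nn.CTCLoss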
diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/misc.py
new file mode 100644
index 0000000..9440f9d
--- /dev/null
+++ b/src/text_recognizer/networks/misc.py
@@ -0,0 +1,28 @@
+"""Miscellaneous neural network functionality."""
+from typing import Tuple
+
+from einops import rearrange
+import torch
+from torch.nn import Unfold
+
+
+def sliding_window(
+ images: torch.Tensor, patch_size: Tuple[int, int], stride: Tuple[int, int]
+) -> torch.Tensor:
+ """Creates patches of an image.
+
+ Args:
+ images (torch.Tensor): A Torch tensor of a 4D image(s), i.e. (batch, channel, height, width).
+ patch_size (Tuple[int, int]): The size of the patches to generate, e.g. 28x28 for EMNIST.
+ stride (Tuple[int, int]): The stride of the sliding window.
+
+ Returns:
+ torch.Tensor: A tensor with the shape (batch, patches, height, width).
+
+ """
+ unfold = Unfold(kernel_size=patch_size, stride=stride)
+ patches = unfold(images)
+ patches = rearrange(
+ patches, "b (h w) c -> b c h w", h=patch_size[0], w=patch_size[1]
+ )
+ return patches
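A quick shape check of the helper. Note that the rearrange pattern folds the channel axis into the patch grid, so it assumes single-channel input, and the name c in the pattern actually binds the number of patches:

    import torch

    images = torch.randn(1, 1, 28, 56)  # (batch, channel, height, width)
    patches = sliding_window(images, patch_size=(28, 28), stride=(28, 28))
    assert patches.shape == (1, 2, 28, 28)  # two non-overlapping 28x28 patches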
diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py
index d704d99..ac2c825 100644
--- a/src/text_recognizer/networks/mlp.py
+++ b/src/text_recognizer/networks/mlp.py
@@ -1,6 +1,7 @@
"""Defines the MLP network."""
from typing import Callable, Dict, List, Optional, Union
+from einops.layers.torch import Rearrange
import torch
from torch import nn
@@ -34,7 +35,8 @@ class MLP(nn.Module):
super().__init__()
if activation_fn is not None:
- activation_fn = getattr(nn, activation_fn)(activation_fn_args)
+ activation_fn_args = activation_fn_args or {}
+ activation_fn = getattr(nn, activation_fn)(**activation_fn_args)
else:
activation_fn = nn.ReLU(inplace=True)
@@ -42,6 +44,7 @@ class MLP(nn.Module):
hidden_size = [hidden_size] * num_layers
self.layers = [
+ Rearrange("b c h w -> b (c h w)"),
nn.Linear(in_features=input_size, out_features=hidden_size[0]),
activation_fn,
]
@@ -63,7 +66,9 @@ class MLP(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""The feedforward."""
- x = torch.flatten(x, start_dim=1)
+ # If the batch dimension is missing, it needs to be added.
+ if len(x.shape) == 3:
+ x = x.unsqueeze(0)
return self.layers(x)
@property
diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py
new file mode 100644
index 0000000..23394b0
--- /dev/null
+++ b/src/text_recognizer/networks/residual_network.py
@@ -0,0 +1 @@
+"""Residual CNN."""
diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt
new file mode 100644
index 0000000..81ef9be
--- /dev/null
+++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt
Binary files differ
diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt
new file mode 100644
index 0000000..49bd166
--- /dev/null
+++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt
Binary files differ
diff --git a/src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt b/src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt
index 46b1cb1..ed73c09 100644
--- a/src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt
+++ b/src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt
Binary files differ