From 3b06ef615a8db67a03927576e0c12fbfb2501f5f Mon Sep 17 00:00:00 2001
From: aktersnurra
Date: Mon, 14 Sep 2020 22:15:47 +0200
Subject: Fixed CTC loss.

---
 src/text_recognizer/datasets/dataset.py            |  7 ++++---
 src/text_recognizer/datasets/emnist_dataset.py     |  9 +--------
 .../datasets/emnist_lines_dataset.py               |  9 +++++++--
 src/text_recognizer/datasets/transforms.py         | 13 +++++++++++++
 src/text_recognizer/datasets/util.py               |  2 +-
 src/text_recognizer/models/line_ctc_model.py       | 20 +++++++++++++++++---
 src/text_recognizer/networks/ctc.py                |  2 +-
 src/text_recognizer/networks/line_lstm_ctc.py      |  6 +++++-
 src/text_recognizer/networks/misc.py               |  1 +
 src/text_recognizer/networks/transformer.py        |  1 +
 src/text_recognizer/networks/wide_resnet.py        |  8 ++++----
 ...istLinesDataset_LineRecurrentNetwork_weights.pt | Bin 15375126 -> 61946486 bytes
 ...IamLinesDataset_LineRecurrentNetwork_weights.pt | Bin 0 -> 45257014 bytes
 13 files changed, 55 insertions(+), 23 deletions(-)
 create mode 100644 src/text_recognizer/datasets/transforms.py
 create mode 100644 src/text_recognizer/networks/transformer.py
 create mode 100644 src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt
(limited to 'src/text_recognizer')

diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py
index f328a0f..05520e5 100644
--- a/src/text_recognizer/datasets/dataset.py
+++ b/src/text_recognizer/datasets/dataset.py
@@ -23,7 +23,7 @@ class Dataset(data.Dataset):
         Args:
             train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False.
-            subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None.
+            subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None.
             transform (Optional[Callable]): Transform(s) for input data. Defaults to None.
             target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None.
 
         Raises:
             ValueError: If subsample_fraction is not None and outside the range (0, 1).
 
         """
+
         self.train = train
         self.split = "train" if self.train else "test"
@@ -96,8 +97,8 @@ class Dataset(data.Dataset):
         if self.subsample_fraction is None:
             return
         num_subsample = int(self.data.shape[0] * self.subsample_fraction)
-        self.data = self.data[:num_subsample]
-        self.targets = self.targets[:num_subsample]
+        self._data = self.data[:num_subsample]
+        self._targets = self.targets[:num_subsample]
 
     def __len__(self) -> int:
         """Returns the length of the dataset."""
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index 81268fb..d01dcee 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -13,17 +13,10 @@ from torchvision.datasets import EMNIST
 from torchvision.transforms import Compose, ToTensor
 
 from text_recognizer.datasets.dataset import Dataset
+from text_recognizer.datasets.transforms import Transpose
 from text_recognizer.datasets.util import DATA_DIRNAME
 
 
-class Transpose:
-    """Transposes the EMNIST image to the correct orientation."""
-
-    def __call__(self, image: Image) -> np.ndarray:
-        """Swaps axis."""
-        return np.array(image).swapaxes(0, 1)
-
-
 class EmnistDataset(Dataset):
     """This is a class for resampling and subsampling the PyTorch EMNIST dataset."""
 
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index 8fa77cd..6268a01 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -19,7 +19,6 @@ from text_recognizer.datasets.util import (
     EmnistMapper,
     ESSENTIALS_FILENAME,
 )
-from text_recognizer.networks import sliding_window
 
 DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines"
 
@@ -32,6 +31,7 @@ class EmnistLinesDataset(Dataset):
         train: bool = False,
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
+        subsample_fraction: float = None,
         max_length: int = 34,
         min_overlap: float = 0,
         max_overlap: float = 0.33,
@@ -44,6 +44,7 @@ class EmnistLinesDataset(Dataset):
             train (bool): Flag for the filename. Defaults to False. Defaults to None.
             transform (Optional[Callable]): The transform of the data. Defaults to None.
             target_transform (Optional[Callable]): The transform of the target. Defaults to None.
+            subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None.
             max_length (int): The maximum number of characters. Defaults to 34.
             min_overlap (float): The minimum overlap between concatenated images. Defaults to 0.
             max_overlap (float): The maximum overlap between concatenated images. Defaults to 0.33.
@@ -52,7 +53,10 @@ class EmnistLinesDataset(Dataset):
 
         """
         super().__init__(
-            train=train, transform=transform, target_transform=target_transform,
+            train=train,
+            transform=transform,
+            target_transform=target_transform,
+            subsample_fraction=subsample_fraction,
         )
 
         # Extract dataset information.
@@ -128,6 +132,7 @@ class EmnistLinesDataset(Dataset):
         if not self.data_filename.exists():
             self._generate_data()
         self._load_data()
+        self._subsample()
 
     def _load_data(self) -> None:
         """Loads the dataset from the h5 file."""
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py
new file mode 100644
index 0000000..17231a8
--- /dev/null
+++ b/src/text_recognizer/datasets/transforms.py
@@ -0,0 +1,13 @@
+"""Transforms for PyTorch datasets."""
+import numpy as np
+from PIL import Image
+import torch
+from torch import Tensor
+
+
+class Transpose:
+    """Transposes the EMNIST image to the correct orientation."""
+
+    def __call__(self, image: Image) -> np.ndarray:
+        """Swaps axis."""
+        return np.array(image).swapaxes(0, 1)
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
index 3acf5db..73968a1 100644
--- a/src/text_recognizer/datasets/util.py
+++ b/src/text_recognizer/datasets/util.py
@@ -126,7 +126,7 @@ class EmnistMapper:
             "?",
         ]
 
-        # padding symbol
+        # padding symbol, and acts as blank symbol as well.
         extra_symbols.append("_")
 
         max_key = max(mapping.keys())
diff --git a/src/text_recognizer/models/line_ctc_model.py b/src/text_recognizer/models/line_ctc_model.py
index 97308a7..af41f18 100644
--- a/src/text_recognizer/models/line_ctc_model.py
+++ b/src/text_recognizer/models/line_ctc_model.py
@@ -62,12 +62,26 @@ class LineCTCModel(Model):
             Tensor: The CTC loss.
 
         """
+
+        # Input lengths on the form [T, B]
         input_lengths = torch.full(
             size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long,
         )
-        target_lengths = torch.full(
-            size=(output.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
+
+        # Configure target tensors for ctc loss.
+        targets_ = Tensor([]).to(self.device)
+        target_lengths = []
+        for t in targets:
+            # Remove padding symbol as it acts as the blank symbol.
+            t = t[t < 79]
+            targets_ = torch.cat([targets_, t])
+            target_lengths.append(len(t))
+
+        targets = targets_.type(dtype=torch.long)
+        target_lengths = (
+            torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device)
         )
+
         return self.criterion(output, targets, input_lengths, target_lengths)
 
     @torch.no_grad()
@@ -93,7 +107,7 @@ class LineCTCModel(Model):
         raw_pred, _ = greedy_decoder(
             predictions=log_probs,
             character_mapper=self.mapper,
-            blank_label=79,
+            blank_label=80,
             collapse_repeated=True,
         )
 
diff --git a/src/text_recognizer/networks/ctc.py b/src/text_recognizer/networks/ctc.py
index 72f18b8..2493d5c 100644
--- a/src/text_recognizer/networks/ctc.py
+++ b/src/text_recognizer/networks/ctc.py
@@ -24,7 +24,7 @@ def greedy_decoder(
         target_lengths (Optional[Tensor]): Length of each target tensor. Defaults to None.
         character_mapper (Optional[Callable]): A emnist/character mapper for mapping integers to characters. Defaults to None.
-        blank_label (int): The blank character to be ignored. Defaults to 79.
+        blank_label (int): The blank character to be ignored. Defaults to 80.
         collapse_repeated (bool): Collapase consecutive predictions of the same character. Defaults to True.
 
     Returns:
diff --git a/src/text_recognizer/networks/line_lstm_ctc.py b/src/text_recognizer/networks/line_lstm_ctc.py
index 988b615..5c57479 100644
--- a/src/text_recognizer/networks/line_lstm_ctc.py
+++ b/src/text_recognizer/networks/line_lstm_ctc.py
@@ -33,8 +33,9 @@ class LineRecurrentNetwork(nn.Module):
         self.hidden_size = hidden_size
         self.encoder = self._configure_encoder(encoder)
         self.flatten = flatten
+        self.fc = nn.Linear(in_features=self.input_size, out_features=self.hidden_size)
         self.rnn = nn.LSTM(
-            input_size=self.input_size,
+            input_size=self.hidden_size,
             hidden_size=self.hidden_size,
             num_layers=num_layers,
         )
@@ -73,6 +74,9 @@ class LineRecurrentNetwork(nn.Module):
         # Avgerage pooling.
         x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t) if self.flatten else x
 
+        # Linear layer between CNN and RNN
+        x = self.fc(x)
+
         # Sequence predictions.
         x, _ = self.rnn(x)
 
diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/misc.py
index cac9e78..1f853e9 100644
--- a/src/text_recognizer/networks/misc.py
+++ b/src/text_recognizer/networks/misc.py
@@ -34,6 +34,7 @@ def activation_function(activation: str) -> Type[nn.Module]:
     """Returns the callable activation function."""
     activation_fns = nn.ModuleDict(
         [
+            ["elu", nn.ELU(inplace=True)],
             ["gelu", nn.GELU()],
             ["leaky_relu", nn.LeakyReLU(negative_slope=1.0e-2, inplace=True)],
             ["none", nn.Identity()],
diff --git a/src/text_recognizer/networks/transformer.py b/src/text_recognizer/networks/transformer.py
new file mode 100644
index 0000000..868d739
--- /dev/null
+++ b/src/text_recognizer/networks/transformer.py
@@ -0,0 +1 @@
+"""TBC."""
diff --git a/src/text_recognizer/networks/wide_resnet.py b/src/text_recognizer/networks/wide_resnet.py
index d1c8f9a..618f414 100644
--- a/src/text_recognizer/networks/wide_resnet.py
+++ b/src/text_recognizer/networks/wide_resnet.py
@@ -28,10 +28,10 @@ def conv_init(module: Type[nn.Module]) -> None:
     classname = module.__class__.__name__
     if classname.find("Conv") != -1:
         nn.init.xavier_uniform_(module.weight, gain=np.sqrt(2))
-        nn.init.constant(module.bias, 0)
+        nn.init.constant_(module.bias, 0)
     elif classname.find("BatchNorm") != -1:
-        nn.init.constant(module.weight, 1)
-        nn.init.constant(module.bias, 0)
+        nn.init.constant_(module.weight, 1)
+        nn.init.constant_(module.bias, 0)
 
 
 class WideBlock(nn.Module):
@@ -183,7 +183,7 @@ class WideResidualNetwork(nn.Module):
             else None
         )
 
-        self.apply(conv_init)
+        # self.apply(conv_init)
 
     def _configure_wide_layer(
         self, in_planes: int, out_planes: int, stride: int, activation: str
diff --git a/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt b/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt
index 9aec6ae..59c06c2 100644
Binary files a/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt and b/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt differ
diff --git a/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt b/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt
new file mode 100644
index 0000000..9bd8ca2
Binary files /dev/null and b/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt differ
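
A note on the CTC loss change in src/text_recognizer/models/line_ctc_model.py above: torch.nn.CTCLoss expects log-probabilities of shape [T, B, C], all targets concatenated into one flat tensor with the padding stripped out, and per-sample input and target lengths. Below is a minimal standalone sketch of that packing with made-up shapes and random tensors; the padding index 79 doubling as the blank follows the patch, while the class count and the nn.CTCLoss(blank=...) configuration are assumptions, since the criterion construction is not part of this diff.

import torch
from torch import nn

T, B, C = 97, 4, 81   # time steps, batch size, number of classes (assumed values)
PAD = 79              # padding symbol index, also used as the blank symbol in the patch

# Stand-ins for the network output and the padded target batch.
log_probs = torch.randn(T, B, C).log_softmax(dim=2)   # [T, B, C]
padded_targets = torch.randint(0, PAD, (B, 34))       # e.g. lines padded to max_length
padded_targets[:, 20:] = PAD                          # pretend the tails are padding

# Every sample uses the full output length as its input length.
input_lengths = torch.full((B,), T, dtype=torch.long)

# Strip the padding symbol and concatenate the targets, as the patch does.
unpadded = [t[t < PAD] for t in padded_targets]
targets = torch.cat(unpadded).long()
target_lengths = torch.tensor([len(t) for t in unpadded], dtype=torch.long)

criterion = nn.CTCLoss(blank=PAD)   # assumed configuration of self.criterion
loss = criterion(log_probs, targets, input_lengths, target_lengths)
print(loss.item())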
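
The line_lstm_ctc.py hunks add a Linear layer between the convolutional encoder and the LSTM, so the recurrent input width becomes hidden_size instead of the encoder output size. Below is a shape-only sketch of the resulting forward pass with a toy encoder and invented sizes; the real network's encoder and final classifier are configurable and not shown in this diff.

import torch
from torch import nn
from einops import rearrange, reduce

b, t = 4, 10                                          # batch size, windows per line (assumed)
c, h, w = 1, 28, 28                                   # window shape (assumed)
input_size, hidden_size, num_classes = 64, 128, 81    # assumed sizes

encoder = nn.Sequential(nn.Conv2d(c, input_size, 3, padding=1), nn.AdaptiveAvgPool2d(1))
fc = nn.Linear(input_size, hidden_size)               # the new layer between CNN and RNN
rnn = nn.LSTM(input_size=hidden_size, hidden_size=hidden_size, num_layers=2)
decoder = nn.Sequential(nn.Linear(hidden_size, num_classes), nn.LogSoftmax(dim=2))

x = torch.randn(b, t, c, h, w)
x = rearrange(x, "b t c h w -> (b t) c h w")          # fold the windows into the batch
x = encoder(x)                                        # (b*t) x input_size x 1 x 1
x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)  # average pool, time first
x = fc(x)                                             # t x b x hidden_size
x, _ = rnn(x)                                         # t x b x hidden_size
log_probs = decoder(x)                                # t x b x num_classes, ready for CTC
print(log_probs.shape)                                # torch.Size([10, 4, 81])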
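
Finally, the blank_label default in networks/ctc.py moves from 79 to 80; whatever index is used at decode time has to match the blank index the loss was trained with. Greedy CTC decoding itself boils down to a per-step argmax, collapsing repeats, and dropping blanks. The function below is a generic sketch of that idea, not the repo's greedy_decoder, which additionally maps the indices back to characters.

import torch


def greedy_ctc_decode(log_probs: torch.Tensor, blank: int) -> list:
    """Collapse repeats and drop blanks from per-step argmax predictions.

    log_probs has shape [T, B, C]; returns one label sequence per batch element.
    """
    best = log_probs.argmax(dim=2)        # [T, B]
    decoded = []
    for b in range(best.shape[1]):
        sequence, previous = [], None
        for label in best[:, b].tolist():
            if label != blank and label != previous:
                sequence.append(label)
            previous = label
        decoded.append(sequence)
    return decoded


# Example with random scores and the patch's decoder default of blank=80.
print(greedy_ctc_decode(torch.randn(10, 4, 81).log_softmax(dim=2), blank=80))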