From 3b06ef615a8db67a03927576e0c12fbfb2501f5f Mon Sep 17 00:00:00 2001
From: aktersnurra <gustaf.rydholm@gmail.com>
Date: Mon, 14 Sep 2020 22:15:47 +0200
Subject: Fixed CTC loss.

---
 src/text_recognizer/datasets/dataset.py            |   7 ++++---
 src/text_recognizer/datasets/emnist_dataset.py     |   9 +--------
 .../datasets/emnist_lines_dataset.py               |   9 +++++++--
 src/text_recognizer/datasets/transforms.py         |  13 +++++++++++++
 src/text_recognizer/datasets/util.py               |   2 +-
 src/text_recognizer/models/line_ctc_model.py       |  20 +++++++++++++++++---
 src/text_recognizer/networks/ctc.py                |   2 +-
 src/text_recognizer/networks/line_lstm_ctc.py      |   6 +++++-
 src/text_recognizer/networks/misc.py               |   1 +
 src/text_recognizer/networks/transformer.py        |   1 +
 src/text_recognizer/networks/wide_resnet.py        |   8 ++++----
 ...istLinesDataset_LineRecurrentNetwork_weights.pt | Bin 15375126 -> 61946486 bytes
 ...IamLinesDataset_LineRecurrentNetwork_weights.pt | Bin 0 -> 45257014 bytes
 13 files changed, 55 insertions(+), 23 deletions(-)
 create mode 100644 src/text_recognizer/datasets/transforms.py
 create mode 100644 src/text_recognizer/networks/transformer.py
 create mode 100644 src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt

(limited to 'src/text_recognizer')

diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py
index f328a0f..05520e5 100644
--- a/src/text_recognizer/datasets/dataset.py
+++ b/src/text_recognizer/datasets/dataset.py
@@ -23,7 +23,7 @@ class Dataset(data.Dataset):
 
         Args:
             train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False.
-            subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None.
+            subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None.
             transform (Optional[Callable]): Transform(s) for input data. Defaults to None.
             target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None.
 
@@ -31,6 +31,7 @@ class Dataset(data.Dataset):
             ValueError: If subsample_fraction is not None and outside the range (0, 1).
 
         """
+
         self.train = train
         self.split = "train" if self.train else "test"
 
@@ -96,8 +97,8 @@ class Dataset(data.Dataset):
         if self.subsample_fraction is None:
             return
         num_subsample = int(self.data.shape[0] * self.subsample_fraction)
-        self.data = self.data[:num_subsample]
-        self.targets = self.targets[:num_subsample]
+        self._data = self.data[:num_subsample]
+        self._targets = self.targets[:num_subsample]
 
     def __len__(self) -> int:
         """Returns the length of the dataset."""
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index 81268fb..d01dcee 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -13,17 +13,10 @@ from torchvision.datasets import EMNIST
 from torchvision.transforms import Compose, ToTensor
 
 from text_recognizer.datasets.dataset import Dataset
+from text_recognizer.datasets.transforms import Transpose
 from text_recognizer.datasets.util import DATA_DIRNAME
 
 
-class Transpose:
-    """Transposes the EMNIST image to the correct orientation."""
-
-    def __call__(self, image: Image) -> np.ndarray:
-        """Swaps axis."""
-        return np.array(image).swapaxes(0, 1)
-
-
 class EmnistDataset(Dataset):
     """This is a class for resampling and subsampling the PyTorch EMNIST dataset."""
 
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index 8fa77cd..6268a01 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -19,7 +19,6 @@ from text_recognizer.datasets.util import (
     EmnistMapper,
     ESSENTIALS_FILENAME,
 )
-from text_recognizer.networks import sliding_window
 
 DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines"
 
@@ -32,6 +31,7 @@ class EmnistLinesDataset(Dataset):
         train: bool = False,
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
+        subsample_fraction: float = None,
         max_length: int = 34,
         min_overlap: float = 0,
         max_overlap: float = 0.33,
@@ -44,6 +44,7 @@ class EmnistLinesDataset(Dataset):
             train (bool): Flag for the filename. Defaults to False. Defaults to None.
             transform (Optional[Callable]): The transform of the data. Defaults to None.
             target_transform (Optional[Callable]): The transform of the target. Defaults to None.
+            subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None.
             max_length (int): The maximum number of characters. Defaults to 34.
             min_overlap (float): The minimum overlap between concatenated images. Defaults to 0.
             max_overlap (float): The maximum overlap between concatenated images. Defaults to 0.33.
@@ -52,7 +53,10 @@ class EmnistLinesDataset(Dataset):
 
         """
         super().__init__(
-            train=train, transform=transform, target_transform=target_transform,
+            train=train,
+            transform=transform,
+            target_transform=target_transform,
+            subsample_fraction=subsample_fraction,
         )
 
         # Extract dataset information.
@@ -128,6 +132,7 @@ class EmnistLinesDataset(Dataset):
         if not self.data_filename.exists():
             self._generate_data()
         self._load_data()
+        self._subsample()
 
     def _load_data(self) -> None:
         """Loads the dataset from the h5 file."""
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py
new file mode 100644
index 0000000..17231a8
--- /dev/null
+++ b/src/text_recognizer/datasets/transforms.py
@@ -0,0 +1,13 @@
+"""Transforms for PyTorch datasets."""
+import numpy as np
+from PIL import Image
+import torch
+from torch import Tensor
+
+
+class Transpose:
+    """Transposes the EMNIST image to the correct orientation."""
+
+    def __call__(self, image: Image) -> np.ndarray:
+        """Swaps axis."""
+        return np.array(image).swapaxes(0, 1)
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
index 3acf5db..73968a1 100644
--- a/src/text_recognizer/datasets/util.py
+++ b/src/text_recognizer/datasets/util.py
@@ -126,7 +126,7 @@ class EmnistMapper:
             "?",
         ]
 
-        # padding symbol
+        # padding symbol, and acts as blank symbol as well.
         extra_symbols.append("_")
 
         max_key = max(mapping.keys())
diff --git a/src/text_recognizer/models/line_ctc_model.py b/src/text_recognizer/models/line_ctc_model.py
index 97308a7..af41f18 100644
--- a/src/text_recognizer/models/line_ctc_model.py
+++ b/src/text_recognizer/models/line_ctc_model.py
@@ -62,12 +62,26 @@ class LineCTCModel(Model):
             Tensor: The CTC loss.
 
         """
+
+        # Input lengths on the form [T, B]
         input_lengths = torch.full(
             size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long,
         )
-        target_lengths = torch.full(
-            size=(output.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
+
+        # Configure target tensors for ctc loss.
+        targets_ = Tensor([]).to(self.device)
+        target_lengths = []
+        for t in targets:
+            # Remove padding symbol as it acts as the blank symbol.
+            t = t[t < 79]
+            targets_ = torch.cat([targets_, t])
+            target_lengths.append(len(t))
+
+        targets = targets_.type(dtype=torch.long)
+        target_lengths = (
+            torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device)
         )
+
         return self.criterion(output, targets, input_lengths, target_lengths)
 
     @torch.no_grad()
@@ -93,7 +107,7 @@ class LineCTCModel(Model):
         raw_pred, _ = greedy_decoder(
             predictions=log_probs,
             character_mapper=self.mapper,
-            blank_label=79,
+            blank_label=80,
             collapse_repeated=True,
         )
 
diff --git a/src/text_recognizer/networks/ctc.py b/src/text_recognizer/networks/ctc.py
index 72f18b8..2493d5c 100644
--- a/src/text_recognizer/networks/ctc.py
+++ b/src/text_recognizer/networks/ctc.py
@@ -24,7 +24,7 @@ def greedy_decoder(
         target_lengths (Optional[Tensor]): Length of each target tensor. Defaults to None.
         character_mapper (Optional[Callable]): A emnist/character mapper for mapping integers to characters.  Defaults
             to None.
-        blank_label (int): The blank character to be ignored. Defaults to 79.
+        blank_label (int): The blank character to be ignored. Defaults to 80.
         collapse_repeated (bool): Collapase consecutive predictions of the same character. Defaults to True.
 
     Returns:
diff --git a/src/text_recognizer/networks/line_lstm_ctc.py b/src/text_recognizer/networks/line_lstm_ctc.py
index 988b615..5c57479 100644
--- a/src/text_recognizer/networks/line_lstm_ctc.py
+++ b/src/text_recognizer/networks/line_lstm_ctc.py
@@ -33,8 +33,9 @@ class LineRecurrentNetwork(nn.Module):
         self.hidden_size = hidden_size
         self.encoder = self._configure_encoder(encoder)
         self.flatten = flatten
+        self.fc = nn.Linear(in_features=self.input_size, out_features=self.hidden_size)
         self.rnn = nn.LSTM(
-            input_size=self.input_size,
+            input_size=self.hidden_size,
             hidden_size=self.hidden_size,
             num_layers=num_layers,
         )
@@ -73,6 +74,9 @@ class LineRecurrentNetwork(nn.Module):
         # Avgerage pooling.
         x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t) if self.flatten else x
 
+        # Linear layer between CNN and RNN
+        x = self.fc(x)
+
         # Sequence predictions.
         x, _ = self.rnn(x)
 
diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/misc.py
index cac9e78..1f853e9 100644
--- a/src/text_recognizer/networks/misc.py
+++ b/src/text_recognizer/networks/misc.py
@@ -34,6 +34,7 @@ def activation_function(activation: str) -> Type[nn.Module]:
     """Returns the callable activation function."""
     activation_fns = nn.ModuleDict(
         [
+            ["elu", nn.ELU(inplace=True)],
             ["gelu", nn.GELU()],
             ["leaky_relu", nn.LeakyReLU(negative_slope=1.0e-2, inplace=True)],
             ["none", nn.Identity()],
diff --git a/src/text_recognizer/networks/transformer.py b/src/text_recognizer/networks/transformer.py
new file mode 100644
index 0000000..868d739
--- /dev/null
+++ b/src/text_recognizer/networks/transformer.py
@@ -0,0 +1 @@
+"""TBC."""
diff --git a/src/text_recognizer/networks/wide_resnet.py b/src/text_recognizer/networks/wide_resnet.py
index d1c8f9a..618f414 100644
--- a/src/text_recognizer/networks/wide_resnet.py
+++ b/src/text_recognizer/networks/wide_resnet.py
@@ -28,10 +28,10 @@ def conv_init(module: Type[nn.Module]) -> None:
     classname = module.__class__.__name__
     if classname.find("Conv") != -1:
         nn.init.xavier_uniform_(module.weight, gain=np.sqrt(2))
-        nn.init.constant(module.bias, 0)
+        nn.init.constant_(module.bias, 0)
     elif classname.find("BatchNorm") != -1:
-        nn.init.constant(module.weight, 1)
-        nn.init.constant(module.bias, 0)
+        nn.init.constant_(module.weight, 1)
+        nn.init.constant_(module.bias, 0)
 
 
 class WideBlock(nn.Module):
@@ -183,7 +183,7 @@ class WideResidualNetwork(nn.Module):
             else None
         )
 
-        self.apply(conv_init)
+        # self.apply(conv_init)
 
     def _configure_wide_layer(
         self, in_planes: int, out_planes: int, stride: int, activation: str
diff --git a/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt b/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt
index 9aec6ae..59c06c2 100644
Binary files a/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt and b/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt differ
diff --git a/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt b/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt
new file mode 100644
index 0000000..9bd8ca2
Binary files /dev/null and b/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt differ
-- 
cgit v1.2.3-70-g09d2