author     aktersnurra <gustaf.rydholm@gmail.com>    2020-09-14 22:15:47 +0200
committer  aktersnurra <gustaf.rydholm@gmail.com>    2020-09-14 22:15:47 +0200
commit     3b06ef615a8db67a03927576e0c12fbfb2501f5f
tree       e1c2b1289971c8480327408de46152481e99b539 /src/text_recognizer
parent     2b63fd952bdc9c7c72edd501cbcdbf3231e98f00
Fixed CTC loss.
Diffstat (limited to 'src/text_recognizer')
-rw-r--r--  src/text_recognizer/datasets/dataset.py | 7
-rw-r--r--  src/text_recognizer/datasets/emnist_dataset.py | 9
-rw-r--r--  src/text_recognizer/datasets/emnist_lines_dataset.py | 9
-rw-r--r--  src/text_recognizer/datasets/transforms.py | 13
-rw-r--r--  src/text_recognizer/datasets/util.py | 2
-rw-r--r--  src/text_recognizer/models/line_ctc_model.py | 20
-rw-r--r--  src/text_recognizer/networks/ctc.py | 2
-rw-r--r--  src/text_recognizer/networks/line_lstm_ctc.py | 6
-rw-r--r--  src/text_recognizer/networks/misc.py | 1
-rw-r--r--  src/text_recognizer/networks/transformer.py | 1
-rw-r--r--  src/text_recognizer/networks/wide_resnet.py | 8
-rw-r--r--  src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt | bin 15375126 -> 61946486 bytes
-rw-r--r--  src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt | bin 0 -> 45257014 bytes
13 files changed, 55 insertions, 23 deletions
diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py
index f328a0f..05520e5 100644
--- a/src/text_recognizer/datasets/dataset.py
+++ b/src/text_recognizer/datasets/dataset.py
@@ -23,7 +23,7 @@ class Dataset(data.Dataset):
Args:
train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False.
- subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None.
+ subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None.
transform (Optional[Callable]): Transform(s) for input data. Defaults to None.
target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None.
@@ -31,6 +31,7 @@ class Dataset(data.Dataset):
ValueError: If subsample_fraction is not None and outside the range (0, 1).
"""
+
self.train = train
self.split = "train" if self.train else "test"
@@ -96,8 +97,8 @@ class Dataset(data.Dataset):
if self.subsample_fraction is None:
return
num_subsample = int(self.data.shape[0] * self.subsample_fraction)
- self.data = self.data[:num_subsample]
- self.targets = self.targets[:num_subsample]
+ self._data = self.data[:num_subsample]
+ self._targets = self.targets[:num_subsample]
def __len__(self) -> int:
"""Returns the length of the dataset."""
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index 81268fb..d01dcee 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -13,17 +13,10 @@ from torchvision.datasets import EMNIST
from torchvision.transforms import Compose, ToTensor
from text_recognizer.datasets.dataset import Dataset
+from text_recognizer.datasets.transforms import Transpose
from text_recognizer.datasets.util import DATA_DIRNAME
-class Transpose:
- """Transposes the EMNIST image to the correct orientation."""
-
- def __call__(self, image: Image) -> np.ndarray:
- """Swaps axis."""
- return np.array(image).swapaxes(0, 1)
-
-
class EmnistDataset(Dataset):
"""This is a class for resampling and subsampling the PyTorch EMNIST dataset."""
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index 8fa77cd..6268a01 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -19,7 +19,6 @@ from text_recognizer.datasets.util import (
EmnistMapper,
ESSENTIALS_FILENAME,
)
-from text_recognizer.networks import sliding_window
DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines"
@@ -32,6 +31,7 @@ class EmnistLinesDataset(Dataset):
train: bool = False,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
+ subsample_fraction: float = None,
max_length: int = 34,
min_overlap: float = 0,
max_overlap: float = 0.33,
@@ -44,6 +44,7 @@ class EmnistLinesDataset(Dataset):
train (bool): Flag for the filename. Defaults to False.
transform (Optional[Callable]): The transform of the data. Defaults to None.
target_transform (Optional[Callable]): The transform of the target. Defaults to None.
+ subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None.
max_length (int): The maximum number of characters. Defaults to 34.
min_overlap (float): The minimum overlap between concatenated images. Defaults to 0.
max_overlap (float): The maximum overlap between concatenated images. Defaults to 0.33.
@@ -52,7 +53,10 @@ class EmnistLinesDataset(Dataset):
"""
super().__init__(
- train=train, transform=transform, target_transform=target_transform,
+ train=train,
+ transform=transform,
+ target_transform=target_transform,
+ subsample_fraction=subsample_fraction,
)
# Extract dataset information.
@@ -128,6 +132,7 @@ class EmnistLinesDataset(Dataset):
if not self.data_filename.exists():
self._generate_data()
self._load_data()
+ self._subsample()
def _load_data(self) -> None:
"""Loads the dataset from the h5 file."""
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py
new file mode 100644
index 0000000..17231a8
--- /dev/null
+++ b/src/text_recognizer/datasets/transforms.py
@@ -0,0 +1,13 @@
+"""Transforms for PyTorch datasets."""
+import numpy as np
+from PIL import Image
+import torch
+from torch import Tensor
+
+
+class Transpose:
+ """Transposes the EMNIST image to the correct orientation."""
+
+ def __call__(self, image: Image) -> np.ndarray:
+ """Swaps axis."""
+ return np.array(image).swapaxes(0, 1)
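Transpose now lives in its own transforms module, so it can be composed with the usual torchvision transforms before tensor conversion. A short usage sketch:

from torchvision.transforms import Compose, ToTensor

from text_recognizer.datasets.transforms import Transpose

transform = Compose([Transpose(), ToTensor()])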
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
index 3acf5db..73968a1 100644
--- a/src/text_recognizer/datasets/util.py
+++ b/src/text_recognizer/datasets/util.py
@@ -126,7 +126,7 @@ class EmnistMapper:
"?",
]
- # padding symbol
+ # Padding symbol, which also acts as the blank symbol.
extra_symbols.append("_")
max_key = max(mapping.keys())
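The underscore is appended after the punctuation symbols and serves as both the padding and the blank symbol. One way such a mapping extension can look, with placeholder values rather than the real EMNIST mapping:

mapping = {0: "0", 1: "1"}                    # stand-in for the EMNIST integer-to-character mapping
extra_symbols = [" ", "!", "?"]               # punctuation added for line datasets
extra_symbols.append("_")                     # padding symbol, also used as the blank
max_key = max(mapping.keys())
extra_mapping = {max_key + 1 + i: symbol for i, symbol in enumerate(extra_symbols)}
mapping.update(extra_mapping)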
diff --git a/src/text_recognizer/models/line_ctc_model.py b/src/text_recognizer/models/line_ctc_model.py
index 97308a7..af41f18 100644
--- a/src/text_recognizer/models/line_ctc_model.py
+++ b/src/text_recognizer/models/line_ctc_model.py
@@ -62,12 +62,26 @@ class LineCTCModel(Model):
Tensor: The CTC loss.
"""
+
+ # Input lengths: shape [B], each entry equal to the output sequence length T.
input_lengths = torch.full(
size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long,
)
- target_lengths = torch.full(
- size=(output.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
+
+ # Configure target tensors for the CTC loss.
+ targets_ = Tensor([]).to(self.device)
+ target_lengths = []
+ for t in targets:
+ # Remove padding symbol as it acts as the blank symbol.
+ t = t[t < 79]
+ targets_ = torch.cat([targets_, t])
+ target_lengths.append(len(t))
+
+ targets = targets_.type(dtype=torch.long)
+ target_lengths = (
+ torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device)
)
+
return self.criterion(output, targets, input_lengths, target_lengths)
@torch.no_grad()
@@ -93,7 +107,7 @@ class LineCTCModel(Model):
raw_pred, _ = greedy_decoder(
predictions=log_probs,
character_mapper=self.mapper,
- blank_label=79,
+ blank_label=80,
collapse_repeated=True,
)
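This hunk is the heart of the fix: padded targets cannot be passed to the CTC loss as-is, because the padding index would be counted as real labels. Below is a self-contained sketch of the same target preparation against torch.nn.CTCLoss, using made-up shapes; the padding index 79 and blank index 80 follow this commit, while the blank actually configured on the model's criterion is not shown here and is an assumption:

import torch
from torch import nn

T, B, C, S = 97, 2, 81, 34                       # time steps, batch, classes, padded target length
log_probs = torch.randn(T, B, C).log_softmax(2)  # stand-in for the network output
padded_targets = torch.randint(0, 79, (B, S))
padded_targets[:, 20:] = 79                      # simulate right-padding with index 79

# One input length per batch element, each equal to the output sequence length T.
input_lengths = torch.full(size=(B,), fill_value=T, dtype=torch.long)

# Flatten the targets and drop the padding, recording the true length of each target.
targets, target_lengths = [], []
for t in padded_targets:
    t = t[t < 79]
    targets.append(t)
    target_lengths.append(len(t))
targets = torch.cat(targets).long()
target_lengths = torch.tensor(target_lengths, dtype=torch.long)

criterion = nn.CTCLoss(blank=80)                 # assumed blank index
loss = criterion(log_probs, targets, input_lengths, target_lengths)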
diff --git a/src/text_recognizer/networks/ctc.py b/src/text_recognizer/networks/ctc.py
index 72f18b8..2493d5c 100644
--- a/src/text_recognizer/networks/ctc.py
+++ b/src/text_recognizer/networks/ctc.py
@@ -24,7 +24,7 @@ def greedy_decoder(
target_lengths (Optional[Tensor]): Length of each target tensor. Defaults to None.
character_mapper (Optional[Callable]): An EMNIST/character mapper for mapping integers to characters. Defaults
to None.
- blank_label (int): The blank character to be ignored. Defaults to 79.
+ blank_label (int): The blank character to be ignored. Defaults to 80.
collapse_repeated (bool): Collapse consecutive predictions of the same character. Defaults to True.
Returns:
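For reference, a minimal greedy CTC decoder matching the behaviour documented above: per-step argmax, optional collapsing of repeats, then blank removal. This is an illustrative sketch, not the repository's greedy_decoder:

import torch
from torch import Tensor


def greedy_decode(log_probs: Tensor, blank_label: int = 80, collapse_repeated: bool = True):
    """Decode log-probabilities of shape [T, B, C] into per-sample lists of label indices."""
    best_path = torch.argmax(log_probs, dim=2)  # [T, B]
    decoded = []
    for b in range(best_path.shape[1]):
        labels = []
        previous = None
        for label in best_path[:, b].tolist():
            if collapse_repeated and label == previous:
                continue
            previous = label
            if label != blank_label:
                labels.append(label)
        decoded.append(labels)
    return decoded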
diff --git a/src/text_recognizer/networks/line_lstm_ctc.py b/src/text_recognizer/networks/line_lstm_ctc.py
index 988b615..5c57479 100644
--- a/src/text_recognizer/networks/line_lstm_ctc.py
+++ b/src/text_recognizer/networks/line_lstm_ctc.py
@@ -33,8 +33,9 @@ class LineRecurrentNetwork(nn.Module):
self.hidden_size = hidden_size
self.encoder = self._configure_encoder(encoder)
self.flatten = flatten
+ self.fc = nn.Linear(in_features=self.input_size, out_features=self.hidden_size)
self.rnn = nn.LSTM(
- input_size=self.input_size,
+ input_size=self.hidden_size,
hidden_size=self.hidden_size,
num_layers=num_layers,
)
@@ -73,6 +74,9 @@ class LineRecurrentNetwork(nn.Module):
# Average pooling.
x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t) if self.flatten else x
+ # Linear layer between the CNN and the RNN.
+ x = self.fc(x)
+
# Sequence predictions.
x, _ = self.rnn(x)
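The new linear layer projects the pooled CNN features down to the hidden size before the LSTM consumes them. A shape sketch under the assumption that the pooled encoder output is [T, B, input_size], with illustrative sizes:

import torch
from torch import nn

T, B, input_size, hidden_size = 97, 2, 256, 128
fc = nn.Linear(in_features=input_size, out_features=hidden_size)
rnn = nn.LSTM(input_size=hidden_size, hidden_size=hidden_size, num_layers=2)

x = torch.randn(T, B, input_size)   # pooled CNN features, one vector per time step
x = fc(x)                           # [T, B, hidden_size]
x, _ = rnn(x)                       # [T, B, hidden_size]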
diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/misc.py
index cac9e78..1f853e9 100644
--- a/src/text_recognizer/networks/misc.py
+++ b/src/text_recognizer/networks/misc.py
@@ -34,6 +34,7 @@ def activation_function(activation: str) -> Type[nn.Module]:
"""Returns the callable activation function."""
activation_fns = nn.ModuleDict(
[
+ ["elu", nn.ELU(inplace=True)],
["gelu", nn.GELU()],
["leaky_relu", nn.LeakyReLU(negative_slope=1.0e-2, inplace=True)],
["none", nn.Identity()],
diff --git a/src/text_recognizer/networks/transformer.py b/src/text_recognizer/networks/transformer.py
new file mode 100644
index 0000000..868d739
--- /dev/null
+++ b/src/text_recognizer/networks/transformer.py
@@ -0,0 +1 @@
+"""TBC."""
diff --git a/src/text_recognizer/networks/wide_resnet.py b/src/text_recognizer/networks/wide_resnet.py
index d1c8f9a..618f414 100644
--- a/src/text_recognizer/networks/wide_resnet.py
+++ b/src/text_recognizer/networks/wide_resnet.py
@@ -28,10 +28,10 @@ def conv_init(module: Type[nn.Module]) -> None:
classname = module.__class__.__name__
if classname.find("Conv") != -1:
nn.init.xavier_uniform_(module.weight, gain=np.sqrt(2))
- nn.init.constant(module.bias, 0)
+ nn.init.constant_(module.bias, 0)
elif classname.find("BatchNorm") != -1:
- nn.init.constant(module.weight, 1)
- nn.init.constant(module.bias, 0)
+ nn.init.constant_(module.weight, 1)
+ nn.init.constant_(module.bias, 0)
class WideBlock(nn.Module):
@@ -183,7 +183,7 @@ class WideResidualNetwork(nn.Module):
else None
)
- self.apply(conv_init)
+ # self.apply(conv_init)
def _configure_wide_layer(
self, in_planes: int, out_planes: int, stride: int, activation: str
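The initialiser now calls the in-place nn.init.constant_ (trailing underscore); nn.init.constant is the deprecated spelling. A sketch of applying such an initialiser to a model, with a bias-None guard added here as an extra assumption (note that the network above currently leaves self.apply(conv_init) commented out):

import numpy as np
from torch import nn


def conv_init(module: nn.Module) -> None:
    """Initialise conv and batch-norm layers; the underscore-suffixed init functions work in place."""
    classname = module.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.xavier_uniform_(module.weight, gain=np.sqrt(2))
        if module.bias is not None:          # bias-None guard added for this sketch
            nn.init.constant_(module.bias, 0)
    elif classname.find("BatchNorm") != -1:
        nn.init.constant_(module.weight, 1)
        nn.init.constant_(module.bias, 0)


model = nn.Sequential(nn.Conv2d(1, 8, kernel_size=3), nn.BatchNorm2d(8))
model.apply(conv_init)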
diff --git a/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt b/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt
index 9aec6ae..59c06c2 100644
--- a/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt
+++ b/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt
Binary files differ
diff --git a/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt b/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt
new file mode 100644
index 0000000..9bd8ca2
--- /dev/null
+++ b/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt
Binary files differ