author    aktersnurra <gustaf.rydholm@gmail.com>    2020-11-08 12:41:04 +0100
committer aktersnurra <gustaf.rydholm@gmail.com>    2020-11-08 12:41:04 +0100
commit    beeaef529e7c893a3475fe27edc880e283373725 (patch)
tree      59eb72562bf7a5a9470c2586e6280600ad94f1ae /src/text_recognizer
parent    4d7713746eb936832e84852e90292936b933e87d (diff)
Trying to get the CNNTransformer to work, but it is hard.
Diffstat (limited to 'src/text_recognizer')
-rw-r--r--  src/text_recognizer/character_predictor.py | 7
-rw-r--r--  src/text_recognizer/datasets/emnist_dataset.py | 5
-rw-r--r--  src/text_recognizer/datasets/emnist_lines_dataset.py | 35
-rw-r--r--  src/text_recognizer/datasets/transforms.py | 27
-rw-r--r--  src/text_recognizer/line_predictor.py | 28
-rw-r--r--  src/text_recognizer/models/__init__.py | 7
-rw-r--r--  src/text_recognizer/models/base.py | 9
-rw-r--r--  src/text_recognizer/models/character_model.py | 3
-rw-r--r--  src/text_recognizer/models/crnn_model.py (renamed from src/text_recognizer/models/line_ctc_model.py) | 10
-rw-r--r--  src/text_recognizer/models/metrics.py | 5
-rw-r--r--  src/text_recognizer/models/transformer_encoder_model.py | 111
-rw-r--r--  src/text_recognizer/models/vision_transformer_model.py | 12
-rw-r--r--  src/text_recognizer/networks/__init__.py | 2
-rw-r--r--  src/text_recognizer/networks/cnn_transformer.py | 58
-rw-r--r--  src/text_recognizer/networks/cnn_transformer_encoder.py | 73
-rw-r--r--  src/text_recognizer/networks/crnn.py | 40
-rw-r--r--  src/text_recognizer/networks/ctc.py | 2
-rw-r--r--  src/text_recognizer/networks/densenet.py | 4
-rw-r--r--  src/text_recognizer/networks/loss.py | 39
-rw-r--r--  src/text_recognizer/networks/transformer/positional_encoding.py | 1
-rw-r--r--  src/text_recognizer/networks/transformer/sparse_transformer.py | 1
-rw-r--r--  src/text_recognizer/networks/transformer/transformer.py | 1
-rw-r--r--  src/text_recognizer/networks/util.py | 4
-rw-r--r--  src/text_recognizer/networks/vision_transformer.py | 19
-rw-r--r--  src/text_recognizer/tests/test_line_predictor.py | 35
-rw-r--r--  src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt | bin 0 -> 5628749 bytes
-rw-r--r--  src/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt | bin 0 -> 14953410 bytes
27 files changed, 453 insertions, 85 deletions
diff --git a/src/text_recognizer/character_predictor.py b/src/text_recognizer/character_predictor.py
index df37e68..ad71289 100644
--- a/src/text_recognizer/character_predictor.py
+++ b/src/text_recognizer/character_predictor.py
@@ -4,6 +4,7 @@ from typing import Dict, Tuple, Type, Union
import numpy as np
from torch import nn
+from text_recognizer import datasets, networks
from text_recognizer.models import CharacterModel
from text_recognizer.util import read_image
@@ -11,9 +12,11 @@ from text_recognizer.util import read_image
class CharacterPredictor:
"""Recognizes the character in handwritten character images."""
- def __init__(self, network_fn: Type[nn.Module]) -> None:
+ def __init__(self, network_fn: str, dataset: str) -> None:
"""Intializes the CharacterModel and load the pretrained weights."""
- self.model = CharacterModel(network_fn=network_fn)
+ network_fn = getattr(networks, network_fn)
+ dataset = getattr(datasets, dataset)
+ self.model = CharacterModel(network_fn=network_fn, dataset=dataset)
self.model.eval()
self.model.use_swa_model()
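
CharacterPredictor now takes component names as strings and resolves them with getattr, so configs can carry plain strings instead of classes. A minimal usage sketch, assuming the exported names from this repository (the WideResidualNetwork/EmnistDataset pairing matches the weight file added in this commit; the image path is hypothetical):

    from text_recognizer.character_predictor import CharacterPredictor

    # Names are resolved via getattr(networks, ...) / getattr(datasets, ...).
    predictor = CharacterPredictor(
        network_fn="WideResidualNetwork", dataset="EmnistDataset"
    )
    prediction, confidence = predictor.predict("a.png")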
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index a8901d6..9884fdf 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -22,6 +22,7 @@ class EmnistDataset(Dataset):
def __init__(
self,
+ pad_token: str = None,
train: bool = False,
sample_to_balance: bool = False,
subsample_fraction: float = None,
@@ -32,6 +33,7 @@ class EmnistDataset(Dataset):
"""Loads the dataset and the mappings.
Args:
+ pad_token (str): The pad token symbol. Defaults to _.
train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False.
sample_to_balance (bool): Resamples the dataset to make it balanced. Defaults to False.
subsample_fraction (float): The fraction of the dataset to use. Defaults to None.
@@ -45,6 +47,7 @@ class EmnistDataset(Dataset):
subsample_fraction=subsample_fraction,
transform=transform,
target_transform=target_transform,
+ pad_token=pad_token,
)
self.sample_to_balance = sample_to_balance
@@ -53,6 +56,8 @@ class EmnistDataset(Dataset):
if transform is None:
self.transform = Compose([Transpose(), ToTensor()])
+ self.target_transform = None
+
self.seed = seed
def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index 6091da8..6871492 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -4,6 +4,7 @@ from collections import defaultdict
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union
+import click
import h5py
from loguru import logger
import numpy as np
@@ -58,13 +59,15 @@ class EmnistLinesDataset(Dataset):
eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
"""
+ self.pad_token = "_" if pad_token is None else pad_token
+
super().__init__(
train=train,
transform=transform,
target_transform=target_transform,
subsample_fraction=subsample_fraction,
init_token=init_token,
- pad_token=pad_token,
+ pad_token=self.pad_token,
eos_token=eos_token,
)
@@ -127,11 +130,7 @@ class EmnistLinesDataset(Dataset):
@property
def data_filename(self) -> Path:
"""Path to the h5 file."""
- filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt"
- if self.train:
- filename = "train_" + filename
- else:
- filename = "test_" + filename
+ filename = "train.pt" if self.train else "test.pt"
return DATA_DIRNAME / filename
def load_or_generate_data(self) -> None:
@@ -147,8 +146,8 @@ class EmnistLinesDataset(Dataset):
"""Loads the dataset from the h5 file."""
logger.debug("EmnistLinesDataset loading data from HDF5...")
with h5py.File(self.data_filename, "r") as f:
- self._data = f["data"][:]
- self._targets = f["targets"][:]
+ self._data = f["data"][()]
+ self._targets = f["targets"][()]
def _generate_data(self) -> str:
"""Generates a dataset with the Brown corpus and Emnist characters."""
@@ -157,7 +156,9 @@ class EmnistLinesDataset(Dataset):
sentence_generator = SentenceGenerator(self.max_length)
# Load emnist dataset.
- emnist = EmnistDataset(train=self.train, sample_to_balance=True)
+ emnist = EmnistDataset(
+ train=self.train, sample_to_balance=True, pad_token=self.pad_token
+ )
emnist.load_or_generate_data()
samples_by_character = get_samples_by_character(
@@ -308,6 +309,18 @@ def convert_strings_to_categorical_labels(
return np.array([[mapping[c] for c in label] for label in labels])
+@click.command()
+@click.option(
+ "--max_length", type=int, default=34, help="Number of characters in a sentence."
+)
+@click.option(
+ "--min_overlap", type=float, default=0.0, help="Min overlap between characters."
+)
+@click.option(
+ "--max_overlap", type=float, default=0.33, help="Max overlap between characters."
+)
+@click.option("--num_train", type=int, default=10_000, help="Number of train examples.")
+@click.option("--num_test", type=int, default=1_000, help="Number of test examples.")
def create_datasets(
max_length: int = 34,
min_overlap: float = 0,
@@ -326,3 +339,7 @@ def create_datasets(
num_samples=num,
)
emnist_lines.load_or_generate_data()
+
+
+if __name__ == "__main__":
+ create_datasets()
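
With the click decorators and the new __main__ guard, dataset generation can be run from the command line. An assumed invocation, with flags taken from the options above:

    python src/text_recognizer/datasets/emnist_lines_dataset.py --max_length 34 \
        --min_overlap 0.0 --max_overlap 0.33 --num_train 10000 --num_test 1000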
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py
index c058972..8deac7f 100644
--- a/src/text_recognizer/datasets/transforms.py
+++ b/src/text_recognizer/datasets/transforms.py
@@ -3,7 +3,7 @@ import numpy as np
from PIL import Image
import torch
from torch import Tensor
-from torchvision.transforms import Compose, ToTensor
+from torchvision.transforms import Compose, Resize, ToPILImage, ToTensor
from text_recognizer.datasets.util import EmnistMapper
@@ -19,28 +19,35 @@ class Transpose:
class AddTokens:
"""Adds start of sequence and end of sequence tokens to target tensor."""
- def __init__(self, init_token: str, pad_token: str, eos_token: str,) -> None:
+ def __init__(self, pad_token: str, eos_token: str, init_token: str = None) -> None:
self.init_token = init_token
self.pad_token = pad_token
self.eos_token = eos_token
- self.emnist_mapper = EmnistMapper(
- init_token=self.init_token,
- pad_token=self.pad_token,
- eos_token=self.eos_token,
- )
+ if self.init_token is not None:
+ self.emnist_mapper = EmnistMapper(
+ init_token=self.init_token,
+ pad_token=self.pad_token,
+ eos_token=self.eos_token,
+ )
+ else:
+ self.emnist_mapper = EmnistMapper(
+ pad_token=self.pad_token, eos_token=self.eos_token,
+ )
self.pad_value = self.emnist_mapper(self.pad_token)
- self.sos_value = self.emnist_mapper(self.init_token)
self.eos_value = self.emnist_mapper(self.eos_token)
def __call__(self, target: Tensor) -> Tensor:
"""Adds a sos token to the begining and a eos token to the end of a target sequence."""
dtype, device = target.dtype, target.device
- sos = torch.tensor([self.sos_value], dtype=dtype, device=device)
# Find where the padding starts.
pad_index = torch.nonzero(target == self.pad_value, as_tuple=False)[0].item()
target[pad_index] = self.eos_value
- target = torch.cat([sos, target], dim=0)
+ if self.init_token is not None:
+ self.sos_value = self.emnist_mapper(self.init_token)
+ sos = torch.tensor([self.sos_value], dtype=dtype, device=device)
+ target = torch.cat([sos, target], dim=0)
+
return target
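
AddTokens now treats the sos token as optional. A construction sketch of both variants; the "<eos>" and "<sos>" symbols here are assumed placeholders, only the "_" pad token is fixed elsewhere in this commit:

    from text_recognizer.datasets.transforms import AddTokens

    # eos-only variant, for models that need no start-of-sequence token.
    add_eos = AddTokens(pad_token="_", eos_token="<eos>")
    # sos + eos variant, matching the old behaviour.
    add_sos_eos = AddTokens(pad_token="_", eos_token="<eos>", init_token="<sos>")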
diff --git a/src/text_recognizer/line_predictor.py b/src/text_recognizer/line_predictor.py
new file mode 100644
index 0000000..981e2c9
--- /dev/null
+++ b/src/text_recognizer/line_predictor.py
@@ -0,0 +1,28 @@
+"""LinePredictor class."""
+import importlib
+from typing import Tuple, Union
+
+import numpy as np
+from torch import nn
+
+from text_recognizer import datasets, networks
+from text_recognizer.models import VisionTransformerModel
+from text_recognizer.util import read_image
+
+
+class LinePredictor:
+ """Given an image of a line of handwritten text, recognizes the text content."""
+
+ def __init__(self, dataset: str, network_fn: str) -> None:
+ network_fn = getattr(networks, network_fn)
+ dataset = getattr(datasets, dataset)
+ self.model = VisionTransformerModel(network_fn=network_fn, dataset=dataset)
+ self.model.eval()
+
+ def predict(self, image_or_filename: Union[np.ndarray, str]) -> Tuple[str, float]:
+ """Predict on a single images contianing a handwritten character."""
+ if isinstance(image_or_filename, str):
+ image = read_image(image_or_filename, grayscale=True)
+ else:
+ image = image_or_filename
+ return self.model.predict_on_image(image)
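
Usage mirrors the updated CharacterPredictor; the test added in this commit constructs it the same way (the image path below is hypothetical):

    from text_recognizer.line_predictor import LinePredictor

    predictor = LinePredictor(dataset="EmnistLineDataset", network_fn="CNNTransformer")
    text, confidence = predictor.predict("line.png")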
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py
index 0855079..28aa52e 100644
--- a/src/text_recognizer/models/__init__.py
+++ b/src/text_recognizer/models/__init__.py
@@ -1,16 +1,19 @@
"""Model modules."""
from .base import Model
from .character_model import CharacterModel
-from .line_ctc_model import LineCTCModel
+from .crnn_model import CRNNModel
from .metrics import accuracy, cer, wer
+from .transformer_encoder_model import TransformerEncoderModel
from .vision_transformer_model import VisionTransformerModel
__all__ = [
"Model",
"cer",
"CharacterModel",
+ "CRNNModel",
"CNNTransfromerModel",
- "LineCTCModel",
"accuracy",
+ "TransformerEncoderModel",
+ "VisionTransformerModel",
"wer",
]
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index cbef787..cc44c92 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -141,11 +141,12 @@ class Model(ABC):
"transform" in self.dataset_args["args"]
and self.dataset_args["args"]["transform"] is not None
):
- transform_ = [
- getattr(transforms_module, t["type"])()
- for t in self.dataset_args["args"]["transform"]
- ]
+ transform_ = []
+ for t in self.dataset_args["args"]["transform"]:
+ args = t["args"] or {}
+ transform_.append(getattr(transforms_module, t["type"])(**args))
self.dataset_args["args"]["transform"] = Compose(transform_)
+
if (
"target_transform" in self.dataset_args["args"]
and self.dataset_args["args"]["target_transform"] is not None
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
index 3cf6695..f9944f3 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/src/text_recognizer/models/character_model.py
@@ -47,8 +47,9 @@ class CharacterModel(Model):
swa_args,
device,
)
+ self.pad_token = dataset_args["args"]["pad_token"]
if self._mapper is None:
- self._mapper = EmnistMapper()
+ self._mapper = EmnistMapper(pad_token=self.pad_token,)
self.tensor_transform = ToTensor()
self.softmax = nn.Softmax(dim=0)
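
The model now reads the pad token from the dataset arguments, so configurations must carry it. A sketch of the minimal shape, with the surrounding keys assumed from how dataset_args is indexed above:

    # Minimal dataset_args carrying the pad token the model now requires.
    dataset_args = {"args": {"pad_token": "_"}}
    pad_token = dataset_args["args"]["pad_token"]  # "_"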
diff --git a/src/text_recognizer/models/line_ctc_model.py b/src/text_recognizer/models/crnn_model.py
index cdc2d8b..1e01a83 100644
--- a/src/text_recognizer/models/line_ctc_model.py
+++ b/src/text_recognizer/models/crnn_model.py
@@ -1,4 +1,4 @@
-"""Defines the LineCTCModel class."""
+"""Defines the CRNNModel class."""
from typing import Callable, Dict, Optional, Tuple, Type, Union
import numpy as np
@@ -13,7 +13,7 @@ from text_recognizer.models.base import Model
from text_recognizer.networks import greedy_decoder
-class LineCTCModel(Model):
+class CRNNModel(Model):
"""Model for predicting a sequence of characters from an image of a text line."""
def __init__(
@@ -47,8 +47,10 @@ class LineCTCModel(Model):
swa_args,
device,
)
+
+ self.pad_token = dataset_args["args"]["pad_token"]
if self._mapper is None:
- self._mapper = EmnistMapper()
+ self._mapper = EmnistMapper(pad_token=self.pad_token,)
self.tensor_transform = ToTensor()
def criterion(self, output: Tensor, targets: Tensor) -> Tensor:
@@ -112,6 +114,6 @@ class LineCTCModel(Model):
log_probs, _ = log_probs.max(dim=2)
predicted_characters = "".join(raw_pred[0])
- confidence_of_prediction = torch.exp(-log_probs.sum()).item()
+ confidence_of_prediction = log_probs.cumprod(dim=0)[-1].item()
return predicted_characters, confidence_of_prediction
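
The new confidence is the running product of the per-step maxima, read off at the last step. A numeric sketch, assuming the tensor holds per-step probabilities:

    import torch

    probs = torch.tensor([0.9, 0.8, 0.95])        # per-step max probabilities, assumed
    confidence = probs.cumprod(dim=0)[-1].item()  # 0.684, i.e. probs.prod()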
diff --git a/src/text_recognizer/models/metrics.py b/src/text_recognizer/models/metrics.py
index 6a26216..42c3c6e 100644
--- a/src/text_recognizer/models/metrics.py
+++ b/src/text_recognizer/models/metrics.py
@@ -17,7 +17,10 @@ def accuracy(outputs: Tensor, labels: Tensor) -> float:
float: The accuracy for the batch.
"""
- _, predicted = torch.max(outputs.data, dim=1)
+ # eos_index = torch.nonzero(labels == eos, as_tuple=False)
+ # eos_index = eos_index[0].item() if eos_index.nelement() else -1
+
+ _, predicted = torch.max(outputs, dim=-1)
acc = (predicted == labels).sum().float() / labels.shape[0]
acc = acc.item()
return acc
diff --git a/src/text_recognizer/models/transformer_encoder_model.py b/src/text_recognizer/models/transformer_encoder_model.py
new file mode 100644
index 0000000..e35e298
--- /dev/null
+++ b/src/text_recognizer/models/transformer_encoder_model.py
@@ -0,0 +1,111 @@
+"""Defines the CNN-Transformer class."""
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets import EmnistMapper
+from text_recognizer.models.base import Model
+
+
+class TransformerEncoderModel(Model):
+ """A class for only using the encoder part in the sequence modelling."""
+
+ def __init__(
+ self,
+ network_fn: Type[nn.Module],
+ dataset: Type[Dataset],
+ network_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ super().__init__(
+ network_fn,
+ dataset,
+ network_args,
+ dataset_args,
+ metrics,
+ criterion,
+ criterion_args,
+ optimizer,
+ optimizer_args,
+ lr_scheduler,
+ lr_scheduler_args,
+ swa_args,
+ device,
+ )
+ # self.init_token = dataset_args["args"]["init_token"]
+ self.pad_token = dataset_args["args"]["pad_token"]
+ self.eos_token = dataset_args["args"]["eos_token"]
+ if network_args is not None:
+ self.max_len = network_args["max_len"]
+ else:
+ self.max_len = 128
+
+ if self._mapper is None:
+ self._mapper = EmnistMapper(
+ # init_token=self.init_token,
+ pad_token=self.pad_token,
+ eos_token=self.eos_token,
+ )
+ self.tensor_transform = ToTensor()
+
+ self.softmax = nn.Softmax(dim=2)
+
+ @torch.no_grad()
+ def _generate_sentence(self, image: Tensor) -> Tuple[List, float]:
+ logits = self.network(image)
+ # Convert logits to probabilities.
+ probs = self.softmax(logits).squeeze(0)
+
+ confidence, pred_tokens = probs.max(1)
+ pred_tokens = pred_tokens
+
+ eos_index = torch.nonzero(
+ pred_tokens == self._mapper(self.eos_token), as_tuple=False,
+ )
+
+ eos_index = eos_index[0].item() if eos_index.nelement() else -1
+
+ predicted_characters = "".join(
+ [self.mapper(x) for x in pred_tokens[:eos_index].tolist()]
+ )
+
+ confidence = np.min(confidence.tolist())
+
+ return predicted_characters, confidence
+
+ @torch.no_grad()
+ def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
+ """Predict on a single input."""
+ self.eval()
+
+ if image.dtype == np.uint8:
+ # Converts an image with range [0, 255] to a PyTorch Tensor with range [0, 1].
+ image = self.tensor_transform(image)
+
+ # Rescale image between 0 and 1.
+ if image.dtype == torch.uint8:
+ # If the image is an unscaled tensor.
+ image = image.type("torch.FloatTensor") / 255
+
+ # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device)
+
+ (predicted_characters, confidence_of_prediction,) = self._generate_sentence(
+ image
+ )
+
+ return predicted_characters, confidence_of_prediction
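
The encoder-only model decodes in a single forward pass and truncates at the first eos. A sketch of the truncation logic with a toy token sequence, where index 3 stands in for the eos token:

    import torch

    pred_tokens = torch.tensor([5, 12, 7, 3, 0, 0])  # 3 plays the role of the eos index
    eos = torch.nonzero(pred_tokens == 3, as_tuple=False)
    eos_index = eos[0].item() if eos.nelement() else -1
    print(pred_tokens[:eos_index])  # tensor([5, 12, 7]): everything before the first eos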
diff --git a/src/text_recognizer/models/vision_transformer_model.py b/src/text_recognizer/models/vision_transformer_model.py
index 20bd4ca..3d36437 100644
--- a/src/text_recognizer/models/vision_transformer_model.py
+++ b/src/text_recognizer/models/vision_transformer_model.py
@@ -53,7 +53,7 @@ class VisionTransformerModel(Model):
if network_args is not None:
self.max_len = network_args["max_len"]
else:
- self.max_len = 128
+ self.max_len = 120
if self._mapper is None:
self._mapper = EmnistMapper(
@@ -73,10 +73,10 @@ class VisionTransformerModel(Model):
confidence_of_predictions = []
trg_indices = [self.mapper(self.init_token)]
- for _ in range(self.max_len):
+ for _ in range(self.max_len - 1):
trg = torch.tensor(trg_indices, device=self.device)[None, :].long()
- trg, trg_mask = self.network.preprocess_target(trg)
- logits = self.network.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+ trg = self.network.preprocess_target(trg)
+ logits = self.network.decoder(trg=trg, memory=memory, trg_mask=None)
# Convert logits to probabilities.
probs = self.softmax(logits)
@@ -112,6 +112,8 @@ class VisionTransformerModel(Model):
# Put the image tensor on the device the model weights are on.
image = image.to(self.device)
- predicted_characters, confidence_of_prediction = self._generate_sentence(image)
+ (predicted_characters, confidence_of_prediction,) = self._generate_sentence(
+ image
+ )
return predicted_characters, confidence_of_prediction
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
index 8b87797..6d88768 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/src/text_recognizer/networks/__init__.py
@@ -1,5 +1,6 @@
"""Network modules."""
from .cnn_transformer import CNNTransformer
+from .cnn_transformer_encoder import CNNTransformerEncoder
from .crnn import ConvolutionalRecurrentNetwork
from .ctc import greedy_decoder
from .densenet import DenseNet
@@ -15,6 +16,7 @@ from .wide_resnet import WideResidualNetwork
__all__ = [
"CNNTransformer",
+ "CNNTransformerEncoder",
"ConvolutionalRecurrentNetwork",
"DenseNet",
"EmbeddingLoss",
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py
index 8666f11..3da2c9f 100644
--- a/src/text_recognizer/networks/cnn_transformer.py
+++ b/src/text_recognizer/networks/cnn_transformer.py
@@ -1,8 +1,7 @@
"""A DETR style transfomers but for text recognition."""
-from typing import Dict, Optional, Tuple, Type
+from typing import Dict, Optional, Tuple
-from einops.layers.torch import Rearrange
-from loguru import logger
+from einops import rearrange
import torch
from torch import nn
from torch import Tensor
@@ -21,23 +20,32 @@ class CNNTransformer(nn.Module):
hidden_dim: int,
vocab_size: int,
num_heads: int,
- max_len: int,
+ adaptive_pool_dim: Tuple,
expansion_dim: int,
dropout_rate: float,
trg_pad_index: int,
backbone: str,
+ out_channels: int,
+ max_len: int,
backbone_args: Optional[Dict] = None,
activation: str = "gelu",
) -> None:
super().__init__()
self.trg_pad_index = trg_pad_index
- self.backbone_args = backbone_args
+
self.backbone = configure_backbone(backbone, backbone_args)
self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
- self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate, max_len)
- self.collapse_spatial_dim = nn.Sequential(
- Rearrange("b t h w -> b t (h w)"), nn.AdaptiveAvgPool2d((None, hidden_dim))
+
+ # self.conv = nn.Conv2d(out_channels, max_len, kernel_size=1)
+
+ self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+ self.row_embed = nn.Parameter(torch.rand(max_len, max_len // 2))
+ self.col_embed = nn.Parameter(torch.rand(max_len, max_len // 2))
+
+ self.adaptive_pool = (
+ nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None
)
+
self.transformer = Transformer(
num_encoder_layers,
num_decoder_layers,
@@ -47,7 +55,8 @@ class CNNTransformer(nn.Module):
dropout_rate,
activation,
)
- self.head = nn.Linear(hidden_dim, vocab_size)
+
+ self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
def _create_trg_mask(self, trg: Tensor) -> Tensor:
# Move this outside the transformer.
@@ -83,8 +92,22 @@ class CNNTransformer(nn.Module):
if len(src.shape) < 4:
src = src[(None,) * (4 - len(src.shape))]
src = self.backbone(src)
- src = self.collapse_spatial_dim(src)
- src = self.position_encoding(src)
+ # src = self.conv(src)
+ if self.adaptive_pool is not None:
+ src = self.adaptive_pool(src)
+ H, W = src.shape[-2:]
+ src = rearrange(src, "b t h w -> b t (h w)")
+
+ # construct positional encodings
+ pos = torch.cat(
+ [
+ self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
+ self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
+ ],
+ dim=-1,
+ ).unsqueeze(0)
+ pos = rearrange(pos, "b h w l -> b l (h w)")
+ src = pos + 0.1 * src
return src
def preprocess_target(self, trg: Tensor) -> Tuple[Tensor, Tensor]:
@@ -97,15 +120,16 @@ class CNNTransformer(nn.Module):
Tuple[Tensor, Tensor]: Encoded target tensor and target mask.
"""
- trg_mask = self._create_trg_mask(trg)
trg = self.character_embedding(trg.long())
trg = self.position_encoding(trg)
- return trg, trg_mask
+ return trg
- def forward(self, x: Tensor, trg: Tensor) -> Tensor:
+ def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
"""Forward pass with CNN transfomer."""
- src = self.preprocess_input(x)
- trg, trg_mask = self.preprocess_target(trg)
- out = self.transformer(src, trg, trg_mask=trg_mask)
+ h = self.preprocess_input(x)
+ trg_mask = self._create_trg_mask(trg)
+ trg = self.preprocess_target(trg)
+ out = self.transformer(h, trg, trg_mask=trg_mask)
+
logits = self.head(out)
return logits
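
The learned row/column embeddings are the DETR-style 2D positional encoding: each spatial cell gets half its vector from its column and half from its row. A shape walkthrough with feature-map dimensions assumed for illustration:

    import torch

    max_len, H, W = 128, 4, 16                           # H, W assumed
    col_embed = torch.rand(max_len, max_len // 2)
    row_embed = torch.rand(max_len, max_len // 2)
    pos = torch.cat(
        [
            col_embed[:W].unsqueeze(0).repeat(H, 1, 1),  # (H, W, max_len // 2)
            row_embed[:H].unsqueeze(1).repeat(1, W, 1),  # (H, W, max_len // 2)
        ],
        dim=-1,
    ).unsqueeze(0)                                       # (1, H, W, max_len)
    print(pos.shape)  # torch.Size([1, 4, 16, 128]); rearranged above to (1, max_len, H * W)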
diff --git a/src/text_recognizer/networks/cnn_transformer_encoder.py b/src/text_recognizer/networks/cnn_transformer_encoder.py
new file mode 100644
index 0000000..93626bf
--- /dev/null
+++ b/src/text_recognizer/networks/cnn_transformer_encoder.py
@@ -0,0 +1,73 @@
+"""Network with a CNN backend and a transformer encoder head."""
+from typing import Dict
+
+from einops import rearrange
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer import PositionalEncoding
+from text_recognizer.networks.util import configure_backbone
+
+
+class CNNTransformerEncoder(nn.Module):
+ """A CNN backbone with Transformer Encoder frontend for sequence prediction."""
+
+ def __init__(
+ self,
+ backbone: str,
+ backbone_args: Dict,
+ mlp_dim: int,
+ d_model: int,
+ nhead: int = 8,
+ dropout_rate: float = 0.1,
+ activation: str = "relu",
+ num_layers: int = 6,
+ num_classes: int = 80,
+ num_channels: int = 256,
+ max_len: int = 97,
+ ) -> None:
+ super().__init__()
+ self.d_model = d_model
+ self.nhead = nhead
+ self.dropout_rate = dropout_rate
+ self.activation = activation
+ self.num_layers = num_layers
+
+ self.backbone = configure_backbone(backbone, backbone_args)
+ self.position_encoding = PositionalEncoding(d_model, dropout_rate)
+ self.encoder = self._configure_encoder()
+
+ self.conv = nn.Conv2d(num_channels, max_len, kernel_size=1)
+
+ self.mlp = nn.Linear(mlp_dim, d_model)
+
+ self.head = nn.Linear(d_model, num_classes)
+
+ def _configure_encoder(self) -> nn.TransformerEncoder:
+ encoder_layer = nn.TransformerEncoderLayer(
+ d_model=self.d_model,
+ nhead=self.nhead,
+ dropout=self.dropout_rate,
+ activation=self.activation,
+ )
+ norm = nn.LayerNorm(self.d_model)
+ return nn.TransformerEncoder(
+ encoder_layer=encoder_layer, num_layers=self.num_layers, norm=norm
+ )
+
+ def forward(self, x: Tensor, targets: Tensor = None) -> Tensor:
+ """Forward pass through the network."""
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
+
+ x = self.conv(self.backbone(x))
+ x = rearrange(x, "b c h w -> b c (h w)")
+ x = self.mlp(x)
+ x = self.position_encoding(x)
+ x = rearrange(x, "b c h-> c b h")
+ x = self.encoder(x)
+ x = rearrange(x, "c b h-> b c h")
+ logits = self.head(x)
+
+ return logits
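
The 1x1 convolution repurposes the channel dimension as the sequence dimension: num_channels backbone channels become max_len time steps. A minimal shape check under assumed backbone output sizes:

    import torch
    from torch import nn

    b, c, h, w = 2, 256, 3, 32                # backbone output shape, assumed
    feats = torch.randn(b, c, h, w)
    conv = nn.Conv2d(256, 97, kernel_size=1)  # num_channels -> max_len
    seq = conv(feats).flatten(2)              # (b, 97, h * w), like "b c h w -> b c (h w)"
    print(seq.shape)                          # torch.Size([2, 97, 96])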
diff --git a/src/text_recognizer/networks/crnn.py b/src/text_recognizer/networks/crnn.py
index 3e605e2..9747429 100644
--- a/src/text_recognizer/networks/crnn.py
+++ b/src/text_recognizer/networks/crnn.py
@@ -1,12 +1,9 @@
"""LSTM with CTC for handwritten text recognition within a line."""
-import importlib
-from pathlib import Path
-from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+from typing import Dict, Tuple
from einops import rearrange, reduce
-from einops.layers.torch import Rearrange, Reduce
+from einops.layers.torch import Rearrange
from loguru import logger
-import torch
from torch import nn
from torch import Tensor
@@ -28,16 +25,21 @@ class ConvolutionalRecurrentNetwork(nn.Module):
patch_size: Tuple[int, int] = (28, 28),
stride: Tuple[int, int] = (1, 14),
recurrent_cell: str = "lstm",
+ avg_pool: bool = False,
+ use_sliding_window: bool = True,
) -> None:
super().__init__()
self.backbone_args = backbone_args or {}
self.patch_size = patch_size
self.stride = stride
- self.sliding_window = self._configure_sliding_window()
+ self.sliding_window = (
+ self._configure_sliding_window() if use_sliding_window else None
+ )
self.input_size = input_size
self.hidden_size = hidden_size
self.backbone = configure_backbone(backbone, backbone_args)
self.bidirectional = bidirectional
+ self.avg_pool = avg_pool
if recurrent_cell.upper() in ["LSTM", "GRU"]:
recurrent_cell = getattr(nn, recurrent_cell)
@@ -76,15 +78,27 @@ class ConvolutionalRecurrentNetwork(nn.Module):
"""Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM."""
if len(x.shape) < 4:
x = x[(None,) * (4 - len(x.shape))]
- x = self.sliding_window(x)
- # Rearrange from a sequence of patches for feedforward network.
- b, t = x.shape[:2]
- x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
- x = self.backbone(x)
+ if self.sliding_window is not None:
+ # Create image patches with a sliding window kernel.
+ x = self.sliding_window(x)
+
+ # Rearrange from a sequence of patches for feedforward network.
+ b, t = x.shape[:2]
+ x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
- # Avgerage pooling.
- x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)
+ x = self.backbone(x)
+
+ # Average pooling.
+ if self.avg_pool:
+ x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)
+ else:
+ x = rearrange(x, "(b t) h -> t b h", b=b, t=t)
+ else:
+ # Encode the entire image with a CNN, and use the channels as temporal dimension.
+ b = x.shape[0]
+ x = self.backbone(x)
+ x = rearrange(x, "b c h w -> c b (h w)", b=b)
# Sequence predictions.
x, _ = self.rnn(x)
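
In the new no-sliding-window branch the backbone channels likewise serve as the temporal axis for the RNN. A sketch of that rearrangement, with the feature-map shape assumed:

    import torch
    from einops import rearrange

    feats = torch.randn(2, 256, 1, 32)               # (b, c, h, w), shapes assumed
    seq = rearrange(feats, "b c h w -> c b (h w)")   # 256 channels -> 256 time steps
    print(seq.shape)                                 # torch.Size([256, 2, 32])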
diff --git a/src/text_recognizer/networks/ctc.py b/src/text_recognizer/networks/ctc.py
index 2493d5c..af9b700 100644
--- a/src/text_recognizer/networks/ctc.py
+++ b/src/text_recognizer/networks/ctc.py
@@ -33,7 +33,7 @@ def greedy_decoder(
"""
if character_mapper is None:
- character_mapper = EmnistMapper()
+ character_mapper = EmnistMapper(pad_token="_") # noqa: S106
predictions = rearrange(torch.argmax(predictions, dim=2), "t b -> b t")
decoded_predictions = []
diff --git a/src/text_recognizer/networks/densenet.py b/src/text_recognizer/networks/densenet.py
index d2aad60..7dc58d9 100644
--- a/src/text_recognizer/networks/densenet.py
+++ b/src/text_recognizer/networks/densenet.py
@@ -72,7 +72,7 @@ class _DenseBlock(nn.Module):
) -> None:
super().__init__()
self.dense_block = self._build_dense_blocks(
- num_layers, in_channels, bn_size, growth_rate, dropout_rate, activation
+ num_layers, in_channels, bn_size, growth_rate, dropout_rate, activation,
)
def _build_dense_blocks(
@@ -219,7 +219,7 @@ class DenseNet(nn.Module):
def forward(self, x: Tensor) -> Tensor:
"""Forward pass of Densenet."""
- # If batch dimenstion is missing, it needs to be added.
+ # If batch dimension is missing, it will be added.
if len(x.shape) < 4:
x = x[(None,) * (4 - len(x.shape))]
return self.densenet(x)
diff --git a/src/text_recognizer/networks/loss.py b/src/text_recognizer/networks/loss.py
index ff843cf..cf9fa0d 100644
--- a/src/text_recognizer/networks/loss.py
+++ b/src/text_recognizer/networks/loss.py
@@ -1,10 +1,12 @@
"""Implementations of custom loss functions."""
from pytorch_metric_learning import distances, losses, miners, reducers
+import torch
from torch import nn
from torch import Tensor
+from torch.autograd import Variable
+import torch.nn.functional as F
-
-__all__ = ["EmbeddingLoss"]
+__all__ = ["EmbeddingLoss", "LabelSmoothingCrossEntropy"]
class EmbeddingLoss:
@@ -32,3 +34,36 @@ class EmbeddingLoss:
hard_pairs = self.miner(embeddings, labels)
loss = self.loss_fn(embeddings, labels, hard_pairs)
return loss
+
+
+class LabelSmoothingCrossEntropy(nn.Module):
+ """Label smoothing loss function."""
+
+ def __init__(
+ self,
+ classes: int,
+ smoothing: float = 0.0,
+ ignore_index: int = None,
+ dim: int = -1,
+ ) -> None:
+ super().__init__()
+ self.confidence = 1.0 - smoothing
+ self.smoothing = smoothing
+ self.ignore_index = ignore_index
+ self.cls = classes
+ self.dim = dim
+
+ def forward(self, pred: Tensor, target: Tensor) -> Tensor:
+ """Calculates the loss."""
+ pred = pred.log_softmax(dim=self.dim)
+ with torch.no_grad():
+ # true_dist = pred.data.clone()
+ true_dist = torch.zeros_like(pred)
+ true_dist.fill_(self.smoothing / (self.cls - 1))
+ true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
+ if self.ignore_index is not None:
+ true_dist[:, self.ignore_index] = 0
+ mask = torch.nonzero(target == self.ignore_index, as_tuple=False)
+ if mask.dim() > 0:
+ true_dist.index_fill_(0, mask.squeeze(), 0.0)
+ return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
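
A usage sketch for the new loss; the class count and ignore index below are placeholders, not values fixed by this commit:

    import torch
    from text_recognizer.networks.loss import LabelSmoothingCrossEntropy

    criterion = LabelSmoothingCrossEntropy(classes=80, smoothing=0.1, ignore_index=79)
    logits = torch.randn(8, 80)            # (batch, num_classes), raw scores
    targets = torch.randint(0, 79, (8,))   # class indices, avoiding the ignore index
    loss = criterion(logits, targets)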
diff --git a/src/text_recognizer/networks/transformer/positional_encoding.py b/src/text_recognizer/networks/transformer/positional_encoding.py
index a47141b..1ba5537 100644
--- a/src/text_recognizer/networks/transformer/positional_encoding.py
+++ b/src/text_recognizer/networks/transformer/positional_encoding.py
@@ -13,6 +13,7 @@ class PositionalEncoding(nn.Module):
) -> None:
super().__init__()
self.dropout = nn.Dropout(p=dropout_rate)
+ self.max_len = max_len
pe = torch.zeros(max_len, hidden_dim)
position = torch.arange(0, max_len).unsqueeze(1)
diff --git a/src/text_recognizer/networks/transformer/sparse_transformer.py b/src/text_recognizer/networks/transformer/sparse_transformer.py
deleted file mode 100644
index 8c391c8..0000000
--- a/src/text_recognizer/networks/transformer/sparse_transformer.py
+++ /dev/null
@@ -1 +0,0 @@
-"""Encoder and Decoder modules using spares activations."""
diff --git a/src/text_recognizer/networks/transformer/transformer.py b/src/text_recognizer/networks/transformer/transformer.py
index 1c9c7dd..c6e943e 100644
--- a/src/text_recognizer/networks/transformer/transformer.py
+++ b/src/text_recognizer/networks/transformer/transformer.py
@@ -230,6 +230,7 @@ class Transformer(nn.Module):
) -> Tensor:
"""Forward pass through the transformer."""
if src.shape[0] != trg.shape[0]:
+ print(trg.shape)
raise RuntimeError("The batch size of the src and trg must be the same.")
if src.shape[2] != trg.shape[2]:
raise RuntimeError(
diff --git a/src/text_recognizer/networks/util.py b/src/text_recognizer/networks/util.py
index 0d08506..b31e640 100644
--- a/src/text_recognizer/networks/util.py
+++ b/src/text_recognizer/networks/util.py
@@ -28,7 +28,7 @@ def sliding_window(
c = images.shape[1]
patches = unfold(images)
patches = rearrange(
- patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1]
+ patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1],
)
return patches
@@ -77,7 +77,7 @@ def configure_backbone(backbone: str, backbone_args: Dict) -> Type[nn.Module]:
if "remove_layers" in backbone_args and backbone_args["remove_layers"] is not None:
backbone = nn.Sequential(
- *list(backbone.children())[0][: -backbone_args["remove_layers"]]
+ *list(backbone.children())[:][: -backbone_args["remove_layers"]]
)
return backbone
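
The changed slice trims layers from the top-level children of the backbone rather than from inside its first child. A toy illustration of the new behaviour:

    from torch import nn

    backbone = nn.Sequential(
        nn.Conv2d(1, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3), nn.ReLU()
    )
    remove_layers = 2
    # Mirrors the new [:][: -remove_layers] slice over children().
    trimmed = nn.Sequential(*list(backbone.children())[:][:-remove_layers])
    print(len(trimmed))  # 2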
diff --git a/src/text_recognizer/networks/vision_transformer.py b/src/text_recognizer/networks/vision_transformer.py
index 4d204d3..f227954 100644
--- a/src/text_recognizer/networks/vision_transformer.py
+++ b/src/text_recognizer/networks/vision_transformer.py
@@ -29,9 +29,9 @@ class VisionTransformer(nn.Module):
num_heads: int,
max_len: int,
expansion_dim: int,
- mlp_dim: int,
dropout_rate: float,
trg_pad_index: int,
+ mlp_dim: Optional[int] = None,
patch_size: Tuple[int, int] = (28, 28),
stride: Tuple[int, int] = (1, 14),
activation: str = "gelu",
@@ -46,6 +46,7 @@ class VisionTransformer(nn.Module):
self.slidning_window = self._configure_sliding_window()
self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate, max_len)
+ self.mlp_dim = mlp_dim
self.use_backbone = False
if backbone is None:
@@ -54,6 +55,8 @@ class VisionTransformer(nn.Module):
)
else:
self.backbone = configure_backbone(backbone, backbone_args)
+ if mlp_dim:
+ self.mlp = nn.Linear(mlp_dim, hidden_dim)
self.use_backbone = True
self.transformer = Transformer(
@@ -66,13 +69,7 @@ class VisionTransformer(nn.Module):
activation,
)
- self.head = nn.Sequential(
- nn.LayerNorm(hidden_dim),
- nn.Linear(hidden_dim, mlp_dim),
- nn.GELU(),
- nn.Dropout(p=dropout_rate),
- nn.Linear(mlp_dim, vocab_size),
- )
+ self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
def _configure_sliding_window(self) -> nn.Sequential:
return nn.Sequential(
@@ -110,7 +107,11 @@ class VisionTransformer(nn.Module):
if self.use_backbone:
x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
x = self.backbone(x)
- x = rearrange(x, "(b t) h -> b t h", b=b, t=t)
+ if self.mlp_dim:
+ x = rearrange(x, "(b t) c h w -> b t (c h w)", b=b, t=t)
+ x = self.mlp(x)
+ else:
+ x = rearrange(x, "(b t) h -> b t h", b=b, t=t)
else:
x = rearrange(x, "b t c h w -> b t (c h w)", b=b, t=t)
x = self.linear_projection(x)
diff --git a/src/text_recognizer/tests/test_line_predictor.py b/src/text_recognizer/tests/test_line_predictor.py
new file mode 100644
index 0000000..eede4d4
--- /dev/null
+++ b/src/text_recognizer/tests/test_line_predictor.py
@@ -0,0 +1,35 @@
+"""Tests for LinePredictor."""
+import os
+from pathlib import Path
+import unittest
+
+
+import editdistance
+import numpy as np
+
+from text_recognizer.datasets import IamLinesDataset
+from text_recognizer.line_predictor import LinePredictor
+import text_recognizer.util as util
+
+SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "support"
+
+os.environ["CUDA_VISIBLE_DEVICES"] = ""
+
+
+class TestEmnistLinePredictor(unittest.TestCase):
+ """Test LinePredictor class on the EmnistLines dataset."""
+
+ def test_filename(self) -> None:
+ """Test that LinePredictor correctly predicts on single images, for several test images."""
+ predictor = LinePredictor(
+ dataset="EmnistLineDataset", network_fn="CNNTransformer"
+ )
+
+ for filename in (SUPPORT_DIRNAME / "emnist_lines").glob("*.png"):
+ pred, conf = predictor.predict(str(filename))
+ true = str(filename.stem)
+ edit_distance = editdistance.eval(pred, true) / len(pred)
+ print(
+ f'Pred: "{pred}" | Confidence: {conf} | True: {true} | Edit distance: {edit_distance}'
+ )
+ self.assertLess(edit_distance, 0.2)
diff --git a/src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt b/src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt
new file mode 100644
index 0000000..726c723
--- /dev/null
+++ b/src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt
Binary files differ
diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt
new file mode 100644
index 0000000..2d5a89b
--- /dev/null
+++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt
Binary files differ