diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-11-08 12:41:04 +0100 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-11-08 12:41:04 +0100 |
commit | beeaef529e7c893a3475fe27edc880e283373725 (patch) | |
tree | 59eb72562bf7a5a9470c2586e6280600ad94f1ae /src/text_recognizer | |
parent | 4d7713746eb936832e84852e90292936b933e87d (diff) |
Trying to get the CNNTransformer to work, but it is hard.
Diffstat (limited to 'src/text_recognizer')
27 files changed, 453 insertions, 85 deletions
diff --git a/src/text_recognizer/character_predictor.py b/src/text_recognizer/character_predictor.py index df37e68..ad71289 100644 --- a/src/text_recognizer/character_predictor.py +++ b/src/text_recognizer/character_predictor.py @@ -4,6 +4,7 @@ from typing import Dict, Tuple, Type, Union import numpy as np from torch import nn +from text_recognizer import datasets, networks from text_recognizer.models import CharacterModel from text_recognizer.util import read_image @@ -11,9 +12,11 @@ from text_recognizer.util import read_image class CharacterPredictor: """Recognizes the character in handwritten character images.""" - def __init__(self, network_fn: Type[nn.Module]) -> None: + def __init__(self, network_fn: str, dataset: str) -> None: """Intializes the CharacterModel and load the pretrained weights.""" - self.model = CharacterModel(network_fn=network_fn) + network_fn = getattr(networks, network_fn) + dataset = getattr(datasets, dataset) + self.model = CharacterModel(network_fn=network_fn, dataset=dataset) self.model.eval() self.model.use_swa_model() diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py index a8901d6..9884fdf 100644 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ b/src/text_recognizer/datasets/emnist_dataset.py @@ -22,6 +22,7 @@ class EmnistDataset(Dataset): def __init__( self, + pad_token: str = None, train: bool = False, sample_to_balance: bool = False, subsample_fraction: float = None, @@ -32,6 +33,7 @@ class EmnistDataset(Dataset): """Loads the dataset and the mappings. Args: + pad_token (str): The pad token symbol. Defaults to _. train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False. sample_to_balance (bool): Resamples the dataset to make it balanced. Defaults to False. subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None. @@ -45,6 +47,7 @@ class EmnistDataset(Dataset): subsample_fraction=subsample_fraction, transform=transform, target_transform=target_transform, + pad_token=pad_token, ) self.sample_to_balance = sample_to_balance @@ -53,6 +56,8 @@ class EmnistDataset(Dataset): if transform is None: self.transform = Compose([Transpose(), ToTensor()]) + self.target_transform = None + self.seed = seed def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]: diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py index 6091da8..6871492 100644 --- a/src/text_recognizer/datasets/emnist_lines_dataset.py +++ b/src/text_recognizer/datasets/emnist_lines_dataset.py @@ -4,6 +4,7 @@ from collections import defaultdict from pathlib import Path from typing import Callable, Dict, List, Optional, Tuple, Union +import click import h5py from loguru import logger import numpy as np @@ -58,13 +59,15 @@ class EmnistLinesDataset(Dataset): eos_token (Optional[str]): String representing the end of sequence token. Defaults to None. """ + self.pad_token = "_" if pad_token is None else pad_token + super().__init__( train=train, transform=transform, target_transform=target_transform, subsample_fraction=subsample_fraction, init_token=init_token, - pad_token=pad_token, + pad_token=self.pad_token, eos_token=eos_token, ) @@ -127,11 +130,7 @@ class EmnistLinesDataset(Dataset): @property def data_filename(self) -> Path: """Path to the h5 file.""" - filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt" - if self.train: - filename = "train_" + filename - else: - filename = "test_" + filename + filename = "train.pt" if self.train else "test.pt" return DATA_DIRNAME / filename def load_or_generate_data(self) -> None: @@ -147,8 +146,8 @@ class EmnistLinesDataset(Dataset): """Loads the dataset from the h5 file.""" logger.debug("EmnistLinesDataset loading data from HDF5...") with h5py.File(self.data_filename, "r") as f: - self._data = f["data"][:] - self._targets = f["targets"][:] + self._data = f["data"][()] + self._targets = f["targets"][()] def _generate_data(self) -> str: """Generates a dataset with the Brown corpus and Emnist characters.""" @@ -157,7 +156,9 @@ class EmnistLinesDataset(Dataset): sentence_generator = SentenceGenerator(self.max_length) # Load emnist dataset. - emnist = EmnistDataset(train=self.train, sample_to_balance=True) + emnist = EmnistDataset( + train=self.train, sample_to_balance=True, pad_token=self.pad_token + ) emnist.load_or_generate_data() samples_by_character = get_samples_by_character( @@ -308,6 +309,18 @@ def convert_strings_to_categorical_labels( return np.array([[mapping[c] for c in label] for label in labels]) +@click.command() +@click.option( + "--max_length", type=int, default=34, help="Number of characters in a sentence." +) +@click.option( + "--min_overlap", type=float, default=0.0, help="Min overlap between characters." +) +@click.option( + "--max_overlap", type=float, default=0.33, help="Max overlap between characters." +) +@click.option("--num_train", type=int, default=10_000, help="Number of train examples.") +@click.option("--num_test", type=int, default=1_000, help="Number of test examples.") def create_datasets( max_length: int = 34, min_overlap: float = 0, @@ -326,3 +339,7 @@ def create_datasets( num_samples=num, ) emnist_lines.load_or_generate_data() + + +if __name__ == "__main__": + create_datasets() diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py index c058972..8deac7f 100644 --- a/src/text_recognizer/datasets/transforms.py +++ b/src/text_recognizer/datasets/transforms.py @@ -3,7 +3,7 @@ import numpy as np from PIL import Image import torch from torch import Tensor -from torchvision.transforms import Compose, ToTensor +from torchvision.transforms import Compose, Resize, ToPILImage, ToTensor from text_recognizer.datasets.util import EmnistMapper @@ -19,28 +19,35 @@ class Transpose: class AddTokens: """Adds start of sequence and end of sequence tokens to target tensor.""" - def __init__(self, init_token: str, pad_token: str, eos_token: str,) -> None: + def __init__(self, pad_token: str, eos_token: str, init_token: str = None) -> None: self.init_token = init_token self.pad_token = pad_token self.eos_token = eos_token - self.emnist_mapper = EmnistMapper( - init_token=self.init_token, - pad_token=self.pad_token, - eos_token=self.eos_token, - ) + if self.init_token is not None: + self.emnist_mapper = EmnistMapper( + init_token=self.init_token, + pad_token=self.pad_token, + eos_token=self.eos_token, + ) + else: + self.emnist_mapper = EmnistMapper( + pad_token=self.pad_token, eos_token=self.eos_token, + ) self.pad_value = self.emnist_mapper(self.pad_token) - self.sos_value = self.emnist_mapper(self.init_token) self.eos_value = self.emnist_mapper(self.eos_token) def __call__(self, target: Tensor) -> Tensor: """Adds a sos token to the begining and a eos token to the end of a target sequence.""" dtype, device = target.dtype, target.device - sos = torch.tensor([self.sos_value], dtype=dtype, device=device) # Find the where padding starts. pad_index = torch.nonzero(target == self.pad_value, as_tuple=False)[0].item() target[pad_index] = self.eos_value - target = torch.cat([sos, target], dim=0) + if self.init_token is not None: + self.sos_value = self.emnist_mapper(self.init_token) + sos = torch.tensor([self.sos_value], dtype=dtype, device=device) + target = torch.cat([sos, target], dim=0) + return target diff --git a/src/text_recognizer/line_predictor.py b/src/text_recognizer/line_predictor.py new file mode 100644 index 0000000..981e2c9 --- /dev/null +++ b/src/text_recognizer/line_predictor.py @@ -0,0 +1,28 @@ +"""LinePredictor class.""" +import importlib +from typing import Tuple, Union + +import numpy as np +from torch import nn + +from text_recognizer import datasets, networks +from text_recognizer.models import VisionTransformerModel +from text_recognizer.util import read_image + + +class LinePredictor: + """Given an image of a line of handwritten text, recognizes the text content.""" + + def __init__(self, dataset: str, network_fn: str) -> None: + network_fn = getattr(networks, network_fn) + dataset = getattr(datasets, dataset) + self.model = VisionTransformerModel(network_fn=network_fn, dataset=dataset) + self.model.eval() + + def predict(self, image_or_filename: Union[np.ndarray, str]) -> Tuple[str, float]: + """Predict on a single images contianing a handwritten character.""" + if isinstance(image_or_filename, str): + image = read_image(image_or_filename, grayscale=True) + else: + image = image_or_filename + return self.model.predict_on_image(image) diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py index 0855079..28aa52e 100644 --- a/src/text_recognizer/models/__init__.py +++ b/src/text_recognizer/models/__init__.py @@ -1,16 +1,19 @@ """Model modules.""" from .base import Model from .character_model import CharacterModel -from .line_ctc_model import LineCTCModel +from .crnn_model import CRNNModel from .metrics import accuracy, cer, wer +from .transformer_encoder_model import TransformerEncoderModel from .vision_transformer_model import VisionTransformerModel __all__ = [ "Model", "cer", "CharacterModel", + "CRNNModel", "CNNTransfromerModel", - "LineCTCModel", "accuracy", + "TransformerEncoderModel", + "VisionTransformerModel", "wer", ] diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index cbef787..cc44c92 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -141,11 +141,12 @@ class Model(ABC): "transform" in self.dataset_args["args"] and self.dataset_args["args"]["transform"] is not None ): - transform_ = [ - getattr(transforms_module, t["type"])() - for t in self.dataset_args["args"]["transform"] - ] + transform_ = [] + for t in self.dataset_args["args"]["transform"]: + args = t["args"] or {} + transform_.append(getattr(transforms_module, t["type"])(**args)) self.dataset_args["args"]["transform"] = Compose(transform_) + if ( "target_transform" in self.dataset_args["args"] and self.dataset_args["args"]["target_transform"] is not None diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py index 3cf6695..f9944f3 100644 --- a/src/text_recognizer/models/character_model.py +++ b/src/text_recognizer/models/character_model.py @@ -47,8 +47,9 @@ class CharacterModel(Model): swa_args, device, ) + self.pad_token = dataset_args["args"]["pad_token"] if self._mapper is None: - self._mapper = EmnistMapper() + self._mapper = EmnistMapper(pad_token=self.pad_token,) self.tensor_transform = ToTensor() self.softmax = nn.Softmax(dim=0) diff --git a/src/text_recognizer/models/line_ctc_model.py b/src/text_recognizer/models/crnn_model.py index cdc2d8b..1e01a83 100644 --- a/src/text_recognizer/models/line_ctc_model.py +++ b/src/text_recognizer/models/crnn_model.py @@ -1,4 +1,4 @@ -"""Defines the LineCTCModel class.""" +"""Defines the CRNNModel class.""" from typing import Callable, Dict, Optional, Tuple, Type, Union import numpy as np @@ -13,7 +13,7 @@ from text_recognizer.models.base import Model from text_recognizer.networks import greedy_decoder -class LineCTCModel(Model): +class CRNNModel(Model): """Model for predicting a sequence of characters from an image of a text line.""" def __init__( @@ -47,8 +47,10 @@ class LineCTCModel(Model): swa_args, device, ) + + self.pad_token = dataset_args["args"]["pad_token"] if self._mapper is None: - self._mapper = EmnistMapper() + self._mapper = EmnistMapper(pad_token=self.pad_token,) self.tensor_transform = ToTensor() def criterion(self, output: Tensor, targets: Tensor) -> Tensor: @@ -112,6 +114,6 @@ class LineCTCModel(Model): log_probs, _ = log_probs.max(dim=2) predicted_characters = "".join(raw_pred[0]) - confidence_of_prediction = torch.exp(-log_probs.sum()).item() + confidence_of_prediction = log_probs.cumprod(dim=0)[-1].item() return predicted_characters, confidence_of_prediction diff --git a/src/text_recognizer/models/metrics.py b/src/text_recognizer/models/metrics.py index 6a26216..42c3c6e 100644 --- a/src/text_recognizer/models/metrics.py +++ b/src/text_recognizer/models/metrics.py @@ -17,7 +17,10 @@ def accuracy(outputs: Tensor, labels: Tensor) -> float: float: The accuracy for the batch. """ - _, predicted = torch.max(outputs.data, dim=1) + # eos_index = torch.nonzero(labels == eos, as_tuple=False) + # eos_index = eos_index[0].item() if eos_index.nelement() else -1 + + _, predicted = torch.max(outputs, dim=-1) acc = (predicted == labels).sum().float() / labels.shape[0] acc = acc.item() return acc diff --git a/src/text_recognizer/models/transformer_encoder_model.py b/src/text_recognizer/models/transformer_encoder_model.py new file mode 100644 index 0000000..e35e298 --- /dev/null +++ b/src/text_recognizer/models/transformer_encoder_model.py @@ -0,0 +1,111 @@ +"""Defines the CNN-Transformer class.""" +from typing import Callable, Dict, List, Optional, Tuple, Type, Union + +import numpy as np +import torch +from torch import nn +from torch import Tensor +from torch.utils.data import Dataset +from torchvision.transforms import ToTensor + +from text_recognizer.datasets import EmnistMapper +from text_recognizer.models.base import Model + + +class TransformerEncoderModel(Model): + """A class for only using the encoder part in the sequence modelling.""" + + def __init__( + self, + network_fn: Type[nn.Module], + dataset: Type[Dataset], + network_args: Optional[Dict] = None, + dataset_args: Optional[Dict] = None, + metrics: Optional[Dict] = None, + criterion: Optional[Callable] = None, + criterion_args: Optional[Dict] = None, + optimizer: Optional[Callable] = None, + optimizer_args: Optional[Dict] = None, + lr_scheduler: Optional[Callable] = None, + lr_scheduler_args: Optional[Dict] = None, + swa_args: Optional[Dict] = None, + device: Optional[str] = None, + ) -> None: + super().__init__( + network_fn, + dataset, + network_args, + dataset_args, + metrics, + criterion, + criterion_args, + optimizer, + optimizer_args, + lr_scheduler, + lr_scheduler_args, + swa_args, + device, + ) + # self.init_token = dataset_args["args"]["init_token"] + self.pad_token = dataset_args["args"]["pad_token"] + self.eos_token = dataset_args["args"]["eos_token"] + if network_args is not None: + self.max_len = network_args["max_len"] + else: + self.max_len = 128 + + if self._mapper is None: + self._mapper = EmnistMapper( + # init_token=self.init_token, + pad_token=self.pad_token, + eos_token=self.eos_token, + ) + self.tensor_transform = ToTensor() + + self.softmax = nn.Softmax(dim=2) + + @torch.no_grad() + def _generate_sentence(self, image: Tensor) -> Tuple[List, float]: + logits = self.network(image) + # Convert logits to probabilities. + probs = self.softmax(logits).squeeze(0) + + confidence, pred_tokens = probs.max(1) + pred_tokens = pred_tokens + + eos_index = torch.nonzero( + pred_tokens == self._mapper(self.eos_token), as_tuple=False, + ) + + eos_index = eos_index[0].item() if eos_index.nelement() else -1 + + predicted_characters = "".join( + [self.mapper(x) for x in pred_tokens[:eos_index].tolist()] + ) + + confidence = np.min(confidence.tolist()) + + return predicted_characters, confidence + + @torch.no_grad() + def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]: + """Predict on a single input.""" + self.eval() + + if image.dtype == np.uint8: + # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. + image = self.tensor_transform(image) + + # Rescale image between 0 and 1. + if image.dtype == torch.uint8: + # If the image is an unscaled tensor. + image = image.type("torch.FloatTensor") / 255 + + # Put the image tensor on the device the model weights are on. + image = image.to(self.device) + + (predicted_characters, confidence_of_prediction,) = self._generate_sentence( + image + ) + + return predicted_characters, confidence_of_prediction diff --git a/src/text_recognizer/models/vision_transformer_model.py b/src/text_recognizer/models/vision_transformer_model.py index 20bd4ca..3d36437 100644 --- a/src/text_recognizer/models/vision_transformer_model.py +++ b/src/text_recognizer/models/vision_transformer_model.py @@ -53,7 +53,7 @@ class VisionTransformerModel(Model): if network_args is not None: self.max_len = network_args["max_len"] else: - self.max_len = 128 + self.max_len = 120 if self._mapper is None: self._mapper = EmnistMapper( @@ -73,10 +73,10 @@ class VisionTransformerModel(Model): confidence_of_predictions = [] trg_indices = [self.mapper(self.init_token)] - for _ in range(self.max_len): + for _ in range(self.max_len - 1): trg = torch.tensor(trg_indices, device=self.device)[None, :].long() - trg, trg_mask = self.network.preprocess_target(trg) - logits = self.network.decoder(trg=trg, memory=memory, trg_mask=trg_mask) + trg = self.network.preprocess_target(trg) + logits = self.network.decoder(trg=trg, memory=memory, trg_mask=None) # Convert logits to probabilities. probs = self.softmax(logits) @@ -112,6 +112,8 @@ class VisionTransformerModel(Model): # Put the image tensor on the device the model weights are on. image = image.to(self.device) - predicted_characters, confidence_of_prediction = self._generate_sentence(image) + (predicted_characters, confidence_of_prediction,) = self._generate_sentence( + image + ) return predicted_characters, confidence_of_prediction diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py index 8b87797..6d88768 100644 --- a/src/text_recognizer/networks/__init__.py +++ b/src/text_recognizer/networks/__init__.py @@ -1,5 +1,6 @@ """Network modules.""" from .cnn_transformer import CNNTransformer +from .cnn_transformer_encoder import CNNTransformerEncoder from .crnn import ConvolutionalRecurrentNetwork from .ctc import greedy_decoder from .densenet import DenseNet @@ -15,6 +16,7 @@ from .wide_resnet import WideResidualNetwork __all__ = [ "CNNTransformer", + "CNNTransformerEncoder", "ConvolutionalRecurrentNetwork", "DenseNet", "EmbeddingLoss", diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py index 8666f11..3da2c9f 100644 --- a/src/text_recognizer/networks/cnn_transformer.py +++ b/src/text_recognizer/networks/cnn_transformer.py @@ -1,8 +1,7 @@ """A DETR style transfomers but for text recognition.""" -from typing import Dict, Optional, Tuple, Type +from typing import Dict, Optional, Tuple -from einops.layers.torch import Rearrange -from loguru import logger +from einops import rearrange import torch from torch import nn from torch import Tensor @@ -21,23 +20,32 @@ class CNNTransformer(nn.Module): hidden_dim: int, vocab_size: int, num_heads: int, - max_len: int, + adaptive_pool_dim: Tuple, expansion_dim: int, dropout_rate: float, trg_pad_index: int, backbone: str, + out_channels: int, + max_len: int, backbone_args: Optional[Dict] = None, activation: str = "gelu", ) -> None: super().__init__() self.trg_pad_index = trg_pad_index - self.backbone_args = backbone_args + self.backbone = configure_backbone(backbone, backbone_args) self.character_embedding = nn.Embedding(vocab_size, hidden_dim) - self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate, max_len) - self.collapse_spatial_dim = nn.Sequential( - Rearrange("b t h w -> b t (h w)"), nn.AdaptiveAvgPool2d((None, hidden_dim)) + + # self.conv = nn.Conv2d(out_channels, max_len, kernel_size=1) + + self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate) + self.row_embed = nn.Parameter(torch.rand(max_len, max_len // 2)) + self.col_embed = nn.Parameter(torch.rand(max_len, max_len // 2)) + + self.adaptive_pool = ( + nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None ) + self.transformer = Transformer( num_encoder_layers, num_decoder_layers, @@ -47,7 +55,8 @@ class CNNTransformer(nn.Module): dropout_rate, activation, ) - self.head = nn.Linear(hidden_dim, vocab_size) + + self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),) def _create_trg_mask(self, trg: Tensor) -> Tensor: # Move this outside the transformer. @@ -83,8 +92,22 @@ class CNNTransformer(nn.Module): if len(src.shape) < 4: src = src[(None,) * (4 - len(src.shape))] src = self.backbone(src) - src = self.collapse_spatial_dim(src) - src = self.position_encoding(src) + # src = self.conv(src) + if self.adaptive_pool is not None: + src = self.adaptive_pool(src) + H, W = src.shape[-2:] + src = rearrange(src, "b t h w -> b t (h w)") + + # construct positional encodings + pos = torch.cat( + [ + self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1), + self.row_embed[:H].unsqueeze(1).repeat(1, W, 1), + ], + dim=-1, + ).unsqueeze(0) + pos = rearrange(pos, "b h w l -> b l (h w)") + src = pos + 0.1 * src return src def preprocess_target(self, trg: Tensor) -> Tuple[Tensor, Tensor]: @@ -97,15 +120,16 @@ class CNNTransformer(nn.Module): Tuple[Tensor, Tensor]: Encoded target tensor and target mask. """ - trg_mask = self._create_trg_mask(trg) trg = self.character_embedding(trg.long()) trg = self.position_encoding(trg) - return trg, trg_mask + return trg - def forward(self, x: Tensor, trg: Tensor) -> Tensor: + def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor: """Forward pass with CNN transfomer.""" - src = self.preprocess_input(x) - trg, trg_mask = self.preprocess_target(trg) - out = self.transformer(src, trg, trg_mask=trg_mask) + h = self.preprocess_input(x) + trg_mask = self._create_trg_mask(trg) + trg = self.preprocess_target(trg) + out = self.transformer(h, trg, trg_mask=trg_mask) + logits = self.head(out) return logits diff --git a/src/text_recognizer/networks/cnn_transformer_encoder.py b/src/text_recognizer/networks/cnn_transformer_encoder.py new file mode 100644 index 0000000..93626bf --- /dev/null +++ b/src/text_recognizer/networks/cnn_transformer_encoder.py @@ -0,0 +1,73 @@ +"""Network with a CNN backend and a transformer encoder head.""" +from typing import Dict + +from einops import rearrange +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.transformer import PositionalEncoding +from text_recognizer.networks.util import configure_backbone + + +class CNNTransformerEncoder(nn.Module): + """A CNN backbone with Transformer Encoder frontend for sequence prediction.""" + + def __init__( + self, + backbone: str, + backbone_args: Dict, + mlp_dim: int, + d_model: int, + nhead: int = 8, + dropout_rate: float = 0.1, + activation: str = "relu", + num_layers: int = 6, + num_classes: int = 80, + num_channels: int = 256, + max_len: int = 97, + ) -> None: + super().__init__() + self.d_model = d_model + self.nhead = nhead + self.dropout_rate = dropout_rate + self.activation = activation + self.num_layers = num_layers + + self.backbone = configure_backbone(backbone, backbone_args) + self.position_encoding = PositionalEncoding(d_model, dropout_rate) + self.encoder = self._configure_encoder() + + self.conv = nn.Conv2d(num_channels, max_len, kernel_size=1) + + self.mlp = nn.Linear(mlp_dim, d_model) + + self.head = nn.Linear(d_model, num_classes) + + def _configure_encoder(self) -> nn.TransformerEncoder: + encoder_layer = nn.TransformerEncoderLayer( + d_model=self.d_model, + nhead=self.nhead, + dropout=self.dropout_rate, + activation=self.activation, + ) + norm = nn.LayerNorm(self.d_model) + return nn.TransformerEncoder( + encoder_layer=encoder_layer, num_layers=self.num_layers, norm=norm + ) + + def forward(self, x: Tensor, targets: Tensor = None) -> Tensor: + """Forward pass through the network.""" + if len(x.shape) < 4: + x = x[(None,) * (4 - len(x.shape))] + + x = self.conv(self.backbone(x)) + x = rearrange(x, "b c h w -> b c (h w)") + x = self.mlp(x) + x = self.position_encoding(x) + x = rearrange(x, "b c h-> c b h") + x = self.encoder(x) + x = rearrange(x, "c b h-> b c h") + logits = self.head(x) + + return logits diff --git a/src/text_recognizer/networks/crnn.py b/src/text_recognizer/networks/crnn.py index 3e605e2..9747429 100644 --- a/src/text_recognizer/networks/crnn.py +++ b/src/text_recognizer/networks/crnn.py @@ -1,12 +1,9 @@ """LSTM with CTC for handwritten text recognition within a line.""" -import importlib -from pathlib import Path -from typing import Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Dict, Tuple from einops import rearrange, reduce -from einops.layers.torch import Rearrange, Reduce +from einops.layers.torch import Rearrange from loguru import logger -import torch from torch import nn from torch import Tensor @@ -28,16 +25,21 @@ class ConvolutionalRecurrentNetwork(nn.Module): patch_size: Tuple[int, int] = (28, 28), stride: Tuple[int, int] = (1, 14), recurrent_cell: str = "lstm", + avg_pool: bool = False, + use_sliding_window: bool = True, ) -> None: super().__init__() self.backbone_args = backbone_args or {} self.patch_size = patch_size self.stride = stride - self.sliding_window = self._configure_sliding_window() + self.sliding_window = ( + self._configure_sliding_window() if use_sliding_window else None + ) self.input_size = input_size self.hidden_size = hidden_size self.backbone = configure_backbone(backbone, backbone_args) self.bidirectional = bidirectional + self.avg_pool = avg_pool if recurrent_cell.upper() in ["LSTM", "GRU"]: recurrent_cell = getattr(nn, recurrent_cell) @@ -76,15 +78,27 @@ class ConvolutionalRecurrentNetwork(nn.Module): """Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM.""" if len(x.shape) < 4: x = x[(None,) * (4 - len(x.shape))] - x = self.sliding_window(x) - # Rearrange from a sequence of patches for feedforward network. - b, t = x.shape[:2] - x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t) - x = self.backbone(x) + if self.sliding_window is not None: + # Create image patches with a sliding window kernel. + x = self.sliding_window(x) + + # Rearrange from a sequence of patches for feedforward network. + b, t = x.shape[:2] + x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t) - # Avgerage pooling. - x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t) + x = self.backbone(x) + + # Avgerage pooling. + if self.avg_pool: + x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t) + else: + x = rearrange(x, "(b t) h -> t b h", b=b, t=t) + else: + # Encode the entire image with a CNN, and use the channels as temporal dimension. + b = x.shape[0] + x = self.backbone(x) + x = rearrange(x, "b c h w -> c b (h w)", b=b) # Sequence predictions. x, _ = self.rnn(x) diff --git a/src/text_recognizer/networks/ctc.py b/src/text_recognizer/networks/ctc.py index 2493d5c..af9b700 100644 --- a/src/text_recognizer/networks/ctc.py +++ b/src/text_recognizer/networks/ctc.py @@ -33,7 +33,7 @@ def greedy_decoder( """ if character_mapper is None: - character_mapper = EmnistMapper() + character_mapper = EmnistMapper(pad_token="_") # noqa: S106 predictions = rearrange(torch.argmax(predictions, dim=2), "t b -> b t") decoded_predictions = [] diff --git a/src/text_recognizer/networks/densenet.py b/src/text_recognizer/networks/densenet.py index d2aad60..7dc58d9 100644 --- a/src/text_recognizer/networks/densenet.py +++ b/src/text_recognizer/networks/densenet.py @@ -72,7 +72,7 @@ class _DenseBlock(nn.Module): ) -> None: super().__init__() self.dense_block = self._build_dense_blocks( - num_layers, in_channels, bn_size, growth_rate, dropout_rate, activation + num_layers, in_channels, bn_size, growth_rate, dropout_rate, activation, ) def _build_dense_blocks( @@ -219,7 +219,7 @@ class DenseNet(nn.Module): def forward(self, x: Tensor) -> Tensor: """Forward pass of Densenet.""" - # If batch dimenstion is missing, it needs to be added. + # If batch dimenstion is missing, it will be added. if len(x.shape) < 4: x = x[(None,) * (4 - len(x.shape))] return self.densenet(x) diff --git a/src/text_recognizer/networks/loss.py b/src/text_recognizer/networks/loss.py index ff843cf..cf9fa0d 100644 --- a/src/text_recognizer/networks/loss.py +++ b/src/text_recognizer/networks/loss.py @@ -1,10 +1,12 @@ """Implementations of custom loss functions.""" from pytorch_metric_learning import distances, losses, miners, reducers +import torch from torch import nn from torch import Tensor +from torch.autograd import Variable +import torch.nn.functional as F - -__all__ = ["EmbeddingLoss"] +__all__ = ["EmbeddingLoss", "LabelSmoothingCrossEntropy"] class EmbeddingLoss: @@ -32,3 +34,36 @@ class EmbeddingLoss: hard_pairs = self.miner(embeddings, labels) loss = self.loss_fn(embeddings, labels, hard_pairs) return loss + + +class LabelSmoothingCrossEntropy(nn.Module): + """Label smoothing loss function.""" + + def __init__( + self, + classes: int, + smoothing: float = 0.0, + ignore_index: int = None, + dim: int = -1, + ) -> None: + super().__init__() + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + self.ignore_index = ignore_index + self.cls = classes + self.dim = dim + + def forward(self, pred: Tensor, target: Tensor) -> Tensor: + """Calculates the loss.""" + pred = pred.log_softmax(dim=self.dim) + with torch.no_grad(): + # true_dist = pred.data.clone() + true_dist = torch.zeros_like(pred) + true_dist.fill_(self.smoothing / (self.cls - 1)) + true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) + if self.ignore_index is not None: + true_dist[:, self.ignore_index] = 0 + mask = torch.nonzero(target == self.ignore_index, as_tuple=False) + if mask.dim() > 0: + true_dist.index_fill_(0, mask.squeeze(), 0.0) + return torch.mean(torch.sum(-true_dist * pred, dim=self.dim)) diff --git a/src/text_recognizer/networks/transformer/positional_encoding.py b/src/text_recognizer/networks/transformer/positional_encoding.py index a47141b..1ba5537 100644 --- a/src/text_recognizer/networks/transformer/positional_encoding.py +++ b/src/text_recognizer/networks/transformer/positional_encoding.py @@ -13,6 +13,7 @@ class PositionalEncoding(nn.Module): ) -> None: super().__init__() self.dropout = nn.Dropout(p=dropout_rate) + self.max_len = max_len pe = torch.zeros(max_len, hidden_dim) position = torch.arange(0, max_len).unsqueeze(1) diff --git a/src/text_recognizer/networks/transformer/sparse_transformer.py b/src/text_recognizer/networks/transformer/sparse_transformer.py deleted file mode 100644 index 8c391c8..0000000 --- a/src/text_recognizer/networks/transformer/sparse_transformer.py +++ /dev/null @@ -1 +0,0 @@ -"""Encoder and Decoder modules using spares activations.""" diff --git a/src/text_recognizer/networks/transformer/transformer.py b/src/text_recognizer/networks/transformer/transformer.py index 1c9c7dd..c6e943e 100644 --- a/src/text_recognizer/networks/transformer/transformer.py +++ b/src/text_recognizer/networks/transformer/transformer.py @@ -230,6 +230,7 @@ class Transformer(nn.Module): ) -> Tensor: """Forward pass through the transformer.""" if src.shape[0] != trg.shape[0]: + print(trg.shape) raise RuntimeError("The batch size of the src and trg must be the same.") if src.shape[2] != trg.shape[2]: raise RuntimeError( diff --git a/src/text_recognizer/networks/util.py b/src/text_recognizer/networks/util.py index 0d08506..b31e640 100644 --- a/src/text_recognizer/networks/util.py +++ b/src/text_recognizer/networks/util.py @@ -28,7 +28,7 @@ def sliding_window( c = images.shape[1] patches = unfold(images) patches = rearrange( - patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1] + patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1], ) return patches @@ -77,7 +77,7 @@ def configure_backbone(backbone: str, backbone_args: Dict) -> Type[nn.Module]: if "remove_layers" in backbone_args and backbone_args["remove_layers"] is not None: backbone = nn.Sequential( - *list(backbone.children())[0][: -backbone_args["remove_layers"]] + *list(backbone.children())[:][: -backbone_args["remove_layers"]] ) return backbone diff --git a/src/text_recognizer/networks/vision_transformer.py b/src/text_recognizer/networks/vision_transformer.py index 4d204d3..f227954 100644 --- a/src/text_recognizer/networks/vision_transformer.py +++ b/src/text_recognizer/networks/vision_transformer.py @@ -29,9 +29,9 @@ class VisionTransformer(nn.Module): num_heads: int, max_len: int, expansion_dim: int, - mlp_dim: int, dropout_rate: float, trg_pad_index: int, + mlp_dim: Optional[int] = None, patch_size: Tuple[int, int] = (28, 28), stride: Tuple[int, int] = (1, 14), activation: str = "gelu", @@ -46,6 +46,7 @@ class VisionTransformer(nn.Module): self.slidning_window = self._configure_sliding_window() self.character_embedding = nn.Embedding(vocab_size, hidden_dim) self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate, max_len) + self.mlp_dim = mlp_dim self.use_backbone = False if backbone is None: @@ -54,6 +55,8 @@ class VisionTransformer(nn.Module): ) else: self.backbone = configure_backbone(backbone, backbone_args) + if mlp_dim: + self.mlp = nn.Linear(mlp_dim, hidden_dim) self.use_backbone = True self.transformer = Transformer( @@ -66,13 +69,7 @@ class VisionTransformer(nn.Module): activation, ) - self.head = nn.Sequential( - nn.LayerNorm(hidden_dim), - nn.Linear(hidden_dim, mlp_dim), - nn.GELU(), - nn.Dropout(p=dropout_rate), - nn.Linear(mlp_dim, vocab_size), - ) + self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),) def _configure_sliding_window(self) -> nn.Sequential: return nn.Sequential( @@ -110,7 +107,11 @@ class VisionTransformer(nn.Module): if self.use_backbone: x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t) x = self.backbone(x) - x = rearrange(x, "(b t) h -> b t h", b=b, t=t) + if self.mlp_dim: + x = rearrange(x, "(b t) c h w -> b t (c h w)", b=b, t=t) + x = self.mlp(x) + else: + x = rearrange(x, "(b t) h -> b t h", b=b, t=t) else: x = rearrange(x, "b t c h w -> b t (c h w)", b=b, t=t) x = self.linear_projection(x) diff --git a/src/text_recognizer/tests/test_line_predictor.py b/src/text_recognizer/tests/test_line_predictor.py new file mode 100644 index 0000000..eede4d4 --- /dev/null +++ b/src/text_recognizer/tests/test_line_predictor.py @@ -0,0 +1,35 @@ +"""Tests for LinePredictor.""" +import os +from pathlib import Path +import unittest + + +import editdistance +import numpy as np + +from text_recognizer.datasets import IamLinesDataset +from text_recognizer.line_predictor import LinePredictor +import text_recognizer.util as util + +SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "support" + +os.environ["CUDA_VISIBLE_DEVICES"] = "" + + +class TestEmnistLinePredictor(unittest.TestCase): + """Test LinePredictor class on the EmnistLines dataset.""" + + def test_filename(self) -> None: + """Test that LinePredictor correctly predicts on single images, for several test images.""" + predictor = LinePredictor( + dataset="EmnistLineDataset", network_fn="CNNTransformer" + ) + + for filename in (SUPPORT_DIRNAME / "emnist_lines").glob("*.png"): + pred, conf = predictor.predict(str(filename)) + true = str(filename.stem) + edit_distance = editdistance.eval(pred, true) / len(pred) + print( + f'Pred: "{pred}" | Confidence: {conf} | True: {true} | Edit distance: {edit_distance}' + ) + self.assertLess(edit_distance, 0.2) diff --git a/src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt b/src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt Binary files differnew file mode 100644 index 0000000..726c723 --- /dev/null +++ b/src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt Binary files differnew file mode 100644 index 0000000..2d5a89b --- /dev/null +++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt |