From 25b5d6983d51e0e791b96a76beb7e49f392cd9a8 Mon Sep 17 00:00:00 2001
From: aktersnurra
Date: Mon, 7 Dec 2020 22:54:04 +0100
Subject: Segmentation working!

---
 .../datasets/emnist_lines_dataset.py               |  11 ++
 .../datasets/iam_paragraphs_dataset.py             |   8 +-
 src/text_recognizer/datasets/transforms.py         |  18 +-
 src/text_recognizer/models/__init__.py             |   2 +
 src/text_recognizer/models/base.py                 |  11 +-
 src/text_recognizer/models/segmentation_model.py   |  75 ++++++++
 src/text_recognizer/models/transformer_model.py    |   4 +-
 src/text_recognizer/networks/__init__.py           |   4 +
 src/text_recognizer/networks/beam.py               |  83 +++++++++
 src/text_recognizer/networks/cnn_transformer.py    |  19 +-
 src/text_recognizer/networks/fcn.py                |  99 ----------
 .../networks/neural_machine_reader.py              | 201 ---------------------
 src/text_recognizer/networks/residual_network.py   |   7 +-
 src/text_recognizer/networks/unet.py               | 159 ++++++++++++----
 src/text_recognizer/paragraph_text_recognizer.py   | 153 ++++++++++++++++
 .../tests/support/iam_paragraphs/a01-000u.jpg      | Bin 0 -> 14890 bytes
 .../tests/test_paragraph_text_recognizer.py        |  37 ++++
 src/text_recognizer/util.py                        |  21 ++-
 ...tationModel_IamParagraphsDataset_FCN_weights.pt | Bin 0 -> 8588813 bytes
 ...ationModel_IamParagraphsDataset_UNet_weights.pt | Bin 0 -> 92335101 bytes
 20 files changed, 551 insertions(+), 361 deletions(-)
 create mode 100644 src/text_recognizer/models/segmentation_model.py
 create mode 100644 src/text_recognizer/networks/beam.py
 delete mode 100644 src/text_recognizer/networks/fcn.py
 delete mode 100644 src/text_recognizer/networks/neural_machine_reader.py
 create mode 100644 src/text_recognizer/paragraph_text_recognizer.py
 create mode 100644 src/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg
 create mode 100644 src/text_recognizer/tests/test_paragraph_text_recognizer.py
 create mode 100644 src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_FCN_weights.pt
 create mode 100644 src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_UNet_weights.pt

diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index 6871492..eddf341 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -10,6 +10,7 @@ from loguru import logger
 import numpy as np
 import torch
 from torch import Tensor
+import torch.nn.functional as F
 from torchvision.transforms import ToTensor
 
 from text_recognizer.datasets.dataset import Dataset
@@ -23,6 +24,8 @@ from text_recognizer.datasets.util import (
 
 DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines"
 
+MAX_WIDTH = 952
+
 
 class EmnistLinesDataset(Dataset):
     """Synthetic dataset of lines from the Brown corpus with Emnist characters."""
@@ -254,6 +257,14 @@ def construct_image_from_string(
     for image in sampled_images:
         concatenated_image[:, x : (x + width)] += image
         x += next_overlap_width
+
+    if concatenated_image.shape[-1] > MAX_WIDTH:
+        concatenated_image = Tensor(concatenated_image).unsqueeze(0)
+        concatenated_image = F.interpolate(
+            concatenated_image, size=MAX_WIDTH, mode="nearest"
+        )
+        concatenated_image = concatenated_image.squeeze(0).numpy()
+
     return np.minimum(255, concatenated_image)

diff --git a/src/text_recognizer/datasets/iam_paragraphs_dataset.py b/src/text_recognizer/datasets/iam_paragraphs_dataset.py
index c1e8fe2..8ba5142 100644
--- a/src/text_recognizer/datasets/iam_paragraphs_dataset.py
+++ b/src/text_recognizer/datasets/iam_paragraphs_dataset.py
@@ -1,4 +1,5 @@
 """IamParagraphsDataset class and functions for data processing."""
+import random
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import click
@@ -71,13 +72,18 @@ class IamParagraphsDataset(Dataset):
         data = self.data[index]
         targets = self.targets[index]
 
+        seed = np.random.randint(SEED)
+        random.seed(seed)  # apply this seed to target transforms
+        torch.manual_seed(seed)  # needed for torchvision 0.7
         if self.transform:
             data = self.transform(data)
 
+        random.seed(seed)  # apply this seed to target transforms
+        torch.manual_seed(seed)  # needed for torchvision 0.7
         if self.target_transform:
             targets = self.target_transform(targets)
 
-        return data, targets
+        return data, targets.long()
 
     @property
     def ids(self) -> Tensor:

diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py
index 1ec23dc..016ec80 100644
--- a/src/text_recognizer/datasets/transforms.py
+++ b/src/text_recognizer/datasets/transforms.py
@@ -4,7 +4,7 @@ from PIL import Image
 import torch
 from torch import Tensor
 import torch.nn.functional as F
-from torchvision.transforms import Compose, RandomAffine, ToTensor
+from torchvision.transforms import Compose, RandomAffine, RandomHorizontalFlip, ToTensor
 
 from text_recognizer.datasets.util import EmnistMapper
 
@@ -77,3 +77,19 @@ class ApplyContrast:
         """Apply mask binary mask to input tensor."""
         mask = x > np.random.RandomState().uniform(low=self.low, high=self.high)
         return x * mask
+
+
+class Unsqueeze:
+    """Add a dimension to the tensor."""
+
+    def __call__(self, x: Tensor) -> Tensor:
+        """Adds dim."""
+        return x.unsqueeze(0)
+
+
+class Squeeze:
+    """Removes the first dimension of a tensor."""
+
+    def __call__(self, x: Tensor) -> Tensor:
+        """Removes first dim."""
+        return x.squeeze(0)

diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py
index bf89404..a645cec 100644
--- a/src/text_recognizer/models/__init__.py
+++ b/src/text_recognizer/models/__init__.py
@@ -2,11 +2,13 @@
 from .base import Model
 from .character_model import CharacterModel
 from .crnn_model import CRNNModel
+from .segmentation_model import SegmentationModel
 from .transformer_model import TransformerModel
 
 __all__ = [
     "CharacterModel",
     "CRNNModel",
     "Model",
+    "SegmentationModel",
     "TransformerModel",
 ]

diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index d394b4c..f2cd4b8 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -159,7 +159,7 @@ class Model(ABC):
         self.test_dataset = self.dataset(train=False, **self.dataset_args["args"])
         self.test_dataset.load_or_generate_data()
 
-        # Set the flag to true to disable ability to load data agian.
+        # Set the flag to true to disable ability to load data again.
        self.data_prepared = True
 
     def train_dataloader(self) -> DataLoader:
@@ -260,7 +260,7 @@ class Model(ABC):
     @property
     def mapping(self) -> Dict:
         """Returns the mapping between network output and Emnist character."""
-        return self._mapper.mapping
+        return self._mapper.mapping if self._mapper is not None else None
 
     def eval(self) -> None:
         """Sets the network to evaluation mode."""
@@ -341,7 +341,7 @@ class Model(ABC):
         if input_shape is not None:
             summary(self.network, input_shape, depth=depth, device=device)
         elif self._input_shape is not None:
-            input_shape = (1,) + tuple(self._input_shape)
+            input_shape = tuple(self._input_shape)
             summary(self.network, input_shape, depth=depth, device=device)
         else:
             logger.warning("Could not print summary as input shape is not set.")
@@ -427,7 +427,7 @@ class Model(ABC):
             )
             shutil.copyfile(filepath, str(checkpoint_path / "best.pt"))
 
-    def load_weights(self, network_fn: Type[nn.Module]) -> None:
+    def load_weights(self, network_fn: Optional[Type[nn.Module]] = None) -> None:
         """Load the network weights."""
         logger.debug("Loading network with pretrained weights.")
         filename = glob(self.weights_filename)[0]
@@ -441,7 +441,8 @@ class Model(ABC):
         weights = state_dict["model_state"]
 
         # Initializes the network with trained weights.
-        self._network = network_fn(**self._network_args)
+        if network_fn is not None:
+            self._network = network_fn(**self._network_args)
         self._network.load_state_dict(weights)
 
         if "swa_network" in state_dict:

diff --git a/src/text_recognizer/models/segmentation_model.py b/src/text_recognizer/models/segmentation_model.py
new file mode 100644
index 0000000..613108a
--- /dev/null
+++ b/src/text_recognizer/models/segmentation_model.py
@@ -0,0 +1,75 @@
+"""Segmentation model for detecting and segmenting lines."""
+from typing import Callable, Dict, Optional, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.models.base import Model
+
+
+class SegmentationModel(Model):
+    """Model for segmenting lines in an image."""
+
+    def __init__(
+        self,
+        network_fn: str,
+        dataset: str,
+        network_args: Optional[Dict] = None,
+        dataset_args: Optional[Dict] = None,
+        metrics: Optional[Dict] = None,
+        criterion: Optional[Callable] = None,
+        criterion_args: Optional[Dict] = None,
+        optimizer: Optional[Callable] = None,
+        optimizer_args: Optional[Dict] = None,
+        lr_scheduler: Optional[Callable] = None,
+        lr_scheduler_args: Optional[Dict] = None,
+        swa_args: Optional[Dict] = None,
+        device: Optional[str] = None,
+    ) -> None:
+        super().__init__(
+            network_fn,
+            dataset,
+            network_args,
+            dataset_args,
+            metrics,
+            criterion,
+            criterion_args,
+            optimizer,
+            optimizer_args,
+            lr_scheduler,
+            lr_scheduler_args,
+            swa_args,
+            device,
+        )
+        self.tensor_transform = ToTensor()
+        self.softmax = nn.Softmax(dim=2)
+
+    @torch.no_grad()
+    def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tensor:
+        """Predict on a single input."""
+        self.eval()
+
+        if image.dtype is np.uint8:
+            # Converts an image with range [0, 255] to a PyTorch Tensor with range [0, 1].
+            image = self.tensor_transform(image)
+
+        # Rescale image between 0 and 1.
+        if image.dtype is torch.uint8 or image.dtype is torch.int64:
+            # If the image is an unscaled tensor.
+            image = image.type("torch.FloatTensor") / 255
+
+        if not torch.is_tensor(image):
+            image = Tensor(image)
+
+        # Put the image tensor on the device the model weights are on.
+        image = image.to(self.device)
+
+        logits = self.forward(image)
+
+        segmentation_mask = torch.argmax(logits, dim=1)
+
+        return segmentation_mask

diff --git a/src/text_recognizer/models/transformer_model.py b/src/text_recognizer/models/transformer_model.py
index 968a047..a912122 100644
--- a/src/text_recognizer/models/transformer_model.py
+++ b/src/text_recognizer/models/transformer_model.py
@@ -18,8 +18,8 @@ class TransformerModel(Model):
 
     def __init__(
         self,
-        network_fn: Type[nn.Module],
-        dataset: Type[Dataset],
+        network_fn: str,
+        dataset: str,
         network_args: Optional[Dict] = None,
         dataset_args: Optional[Dict] = None,
         metrics: Optional[Dict] = None,

diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
index 1635039..f958672 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/src/text_recognizer/networks/__init__.py
@@ -3,11 +3,13 @@ from .cnn_transformer import CNNTransformer
 from .crnn import ConvolutionalRecurrentNetwork
 from .ctc import greedy_decoder
 from .densenet import DenseNet
+from .fcn import FCN
 from .lenet import LeNet
 from .metrics import accuracy, accuracy_ignore_pad, cer, wer
 from .mlp import MLP
 from .residual_network import ResidualNetwork, ResidualNetworkEncoder
 from .transformer import Transformer
+from .unet import UNet
 from .util import sliding_window
 from .wide_resnet import WideResidualNetwork
 
@@ -18,12 +20,14 @@ __all__ = [
     "CNNTransformer",
     "ConvolutionalRecurrentNetwork",
     "DenseNet",
+    "FCN",
     "greedy_decoder",
     "MLP",
     "LeNet",
     "ResidualNetwork",
     "ResidualNetworkEncoder",
     "sliding_window",
+    "UNet",
     "Transformer",
     "wer",
     "WideResidualNetwork",

diff --git a/src/text_recognizer/networks/beam.py b/src/text_recognizer/networks/beam.py
new file mode 100644
index 0000000..dccccdb
--- /dev/null
+++ b/src/text_recognizer/networks/beam.py
@@ -0,0 +1,83 @@
+"""Implementation of beam search decoder for a sequence to sequence network.
+
+Stolen from: https://github.com/budzianowski/PyTorch-Beam-Search-Decoding/blob/master/decode_beam.py
+
+"""
+# from typing import List
+# from queue import PriorityQueue
+
+# from loguru import logger
+# import torch
+# from torch import nn
+# from torch import Tensor
+# import torch.nn.functional as F
+
+
+# class Node:
+#     def __init__(
+#         self, parent: Node, target_index: int, log_prob: Tensor, length: int
+#     ) -> None:
+#         self.parent = parent
+#         self.target_index = target_index
+#         self.log_prob = log_prob
+#         self.length = length
+#         self.reward = 0.0
+
+#     def eval(self, alpha: float = 1.0) -> Tensor:
+#         return self.log_prob / (self.length - 1 + 1e-6) + alpha * self.reward
+
+
+# @torch.no_grad()
+# def beam_decoder(
+#     network, mapper, device, memory: Tensor = None, max_len: int = 97,
+# ) -> Tensor:
+#     beam_width = 10
+#     topk = 1  # How many sentences to generate.
+
+#     trg_indices = [mapper(mapper.init_token)]
+
+#     end_nodes = []
+
+#     node = Node(None, trg_indices, 0, 1)
+#     nodes = PriorityQueue()
+
+#     nodes.put((node.eval(), node))
+#     q_size = 1
+
+#     # Beam search
+#     for _ in range(max_len):
+#         if q_size > 2000:
+#             logger.warning("Could not decode input")
+#             break
+
+#         # Fetch the best node.
+#         score, n = nodes.get()
+#         decoder_input = n.target_index
+
+#         if n.target_index == mapper(mapper.eos_token) and n.parent is not None:
+#             end_nodes.append((score, n))
+
+#             # If we reached the maximum number of sentences required.
+#             if len(end_nodes) >= 1:
+#                 break
+#             else:
+#                 continue
+
+#         # Forward pass with transformer.
+#         trg = torch.tensor(trg_indices, device=device)[None, :].long()
+#         trg = network.target_embedding(trg)
+#         logits = network.decoder(trg=trg, memory=memory, trg_mask=None)
+#         log_prob = F.log_softmax(logits, dim=2)
+
+#         log_prob, indices = torch.topk(log_prob, beam_width)
+
+#         for new_k in range(beam_width):
+#             # TODO: continue from here
+#             token_index = indices[0][new_k].view(1, -1)
+#             log_p = log_prob[0][new_k].item()
+
+#             node = Node()
+
+#         pass
+
+#     pass

diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py
index 16c7a41..b2b74b3 100644
--- a/src/text_recognizer/networks/cnn_transformer.py
+++ b/src/text_recognizer/networks/cnn_transformer.py
@@ -88,10 +88,14 @@ class CNNTransformer(nn.Module):
         if len(src.shape) < 4:
             src = src[(None,) * (4 - len(src.shape))]
         src = self.backbone(src)
-        src = rearrange(src, "b c h w -> b w c h")
+        if self.adaptive_pool is not None:
+            src = rearrange(src, "b c h w -> b w c h")
             src = self.adaptive_pool(src)
-        src = src.squeeze(3)
+            src = src.squeeze(3)
+        else:
+            src = rearrange(src, "b c h w -> b (w h) c")
+
         src = self.position_encoding(src)
         return src
 
@@ -110,12 +114,17 @@ class CNNTransformer(nn.Module):
         trg = self.position_encoding(trg)
         return trg
 
-    def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
-        """Forward pass with CNN transfomer."""
-        h = self.extract_image_features(x)
+    def decode_image_features(self, h: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+        """Takes image features from the backbone and decodes them with the transformer."""
         trg_mask = self._create_trg_mask(trg)
         trg = self.target_embedding(trg)
         out = self.transformer(h, trg, trg_mask=trg_mask)
 
         logits = self.head(out)
         return logits
+
+    def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+        """Forward pass with the CNN transformer."""
+        h = self.extract_image_features(x)
+        logits = self.decode_image_features(h, trg)
+        return logits

diff --git a/src/text_recognizer/networks/fcn.py b/src/text_recognizer/networks/fcn.py
deleted file mode 100644
index f9c4fd4..0000000
--- a/src/text_recognizer/networks/fcn.py
+++ /dev/null
@@ -1,99 +0,0 @@
-"""Fully Convolutional Network (FCN) with dilated kernels for global context."""
-from typing import List, Tuple, Type
-
-import torch
-from torch import nn
-from torch import Tensor
-
-
-from text_recognizer.networks.util import activation_function
-
-
-class _DilatedBlock(nn.Module):
-    def __init__(
-        self,
-        channels: List[int],
-        kernel_sizes: List[int],
-        dilations: List[int],
-        paddings: List[int],
-        activation_fn: Type[nn.Module],
-    ) -> None:
-        super().__init__()
-        self.dilation_conv = nn.Sequential(
-            nn.Conv2d(
-                in_channels=channels[0],
-                out_channels=channels[1],
-                kernel_size=kernel_sizes[0],
-                stride=1,
-                dilation=dilations[0],
-                padding=paddings[0],
-            ),
-            nn.Conv2d(
-                in_channels=channels[1],
-                out_channels=channels[1] // 2,
-                kernel_size=kernel_sizes[1],
-                stride=1,
-                dilation=dilations[1],
-                padding=paddings[1],
-            ),
-        )
-        self.activation_fn = activation_fn
-
-        self.conv = nn.Conv2d(
-            in_channels=channels[0],
-            out_channels=channels[1] // 2,
-            kernel_size=1,
-            dilation=1,
-            stride=1,
-        )
-
-    def forward(self, x: Tensor) -> Tensor:
-        residual = self.conv(x)
-        x = self.dilation_conv(x)
-        x = torch.cat((x, residual), dim=1)
-        return self.activation_fn(x)
-
-
-class FCN(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        base_channels: int,
-        out_channels: int,
-        kernel_size: int,
-        dilations: Tuple[int] = (3, 7),
-        paddings: Tuple[int] = (9,
21), - num_blocks: int = 14, - activation: str = "elu", - ) -> None: - super().__init__() - self.kernel_sizes = [kernel_size] * num_blocks - self.channels = [in_channels] + [base_channels] * (num_blocks - 1) - self.out_channels = out_channels - self.dilations = [dilations[0]] * (num_blocks // 2) + [dilations[1]] * ( - num_blocks // 2 - ) - self.paddings = [paddings[0]] * (num_blocks // 2) + [paddings[1]] * ( - num_blocks // 2 - ) - self.activation_fn = activation_function(activation) - self.fcn = self._configure_fcn() - - def _configure_fcn(self) -> nn.Sequential: - layers = [] - for i in range(0, len(self.channels), 2): - layers.append( - _DilatedBlock( - self.channels[i : i + 2], - self.kernel_sizes[i : i + 2], - self.dilations[i : i + 2], - self.paddings[i : i + 2], - self.activation_fn, - ) - ) - layers.append( - nn.Conv2d(self.channels[-1], self.out_channels, kernel_size=1, stride=1) - ) - return nn.Sequential(*layers) - - def forward(self, x: Tensor) -> Tensor: - return self.fcn(x) diff --git a/src/text_recognizer/networks/neural_machine_reader.py b/src/text_recognizer/networks/neural_machine_reader.py deleted file mode 100644 index 7f8c49b..0000000 --- a/src/text_recognizer/networks/neural_machine_reader.py +++ /dev/null @@ -1,201 +0,0 @@ -"""Sequence to sequence network with RNN cells.""" -# from typing import Dict, Optional, Tuple - -# from einops import rearrange -# from einops.layers.torch import Rearrange -# import torch -# from torch import nn -# from torch import Tensor - -# from text_recognizer.networks.util import configure_backbone - - -# class Encoder(nn.Module): -# def __init__( -# self, -# embedding_dim: int, -# encoder_dim: int, -# decoder_dim: int, -# dropout_rate: float = 0.1, -# ) -> None: -# super().__init__() -# self.rnn = nn.GRU( -# input_size=embedding_dim, hidden_size=encoder_dim, bidirectional=True -# ) -# self.fc = nn.Sequential( -# nn.Linear(in_features=2 * encoder_dim, out_features=decoder_dim), nn.Tanh() -# ) -# self.dropout = nn.Dropout(p=dropout_rate) - -# def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: -# """Encodes a sequence of tensors with a bidirectional GRU. - -# Args: -# x (Tensor): A input sequence. - -# Shape: -# - x: :math:`(T, N, E)`. -# - output[0]: :math:`(T, N, 2 * E)`. -# - output[1]: :math:`(T, N, D)`. - -# where T is the sequence length, N is the batch size, E is the -# embedding/encoder dimension, and D is the decoder dimension. - -# Returns: -# Tuple[Tensor, Tensor]: The encoder output and the hidden state of the -# encoder. - -# """ - -# output, hidden = self.rnn(x) - -# # Get the hidden state from the forward and backward rnn. -# hidden_state = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1) - -# # Apply fully connected layer and tanh activation. -# hidden_state = self.fc(hidden_state) - -# return output, hidden_state - - -# class Attention(nn.Module): -# def __init__(self, encoder_dim: int, decoder_dim: int) -> None: -# super().__init__() -# self.atten = nn.Linear( -# in_features=2 * encoder_dim + decoder_dim, out_features=decoder_dim -# ) -# self.value = nn.Linear(in_features=decoder_dim, out_features=1, bias=False) - -# def forward(self, hidden_state: Tensor, encoder_outputs: Tensor) -> Tensor: -# """Short summary. - -# Args: -# hidden_state (Tensor): Description of parameter `h`. -# encoder_outputs (Tensor): Description of parameter `enc_out`. - -# Shape: -# - x: :math:`(T, N, E)`. -# - output[0]: :math:`(T, N, 2 * E)`. -# - output[1]: :math:`(T, N, D)`. 
- -# where T is the sequence length, N is the batch size, E is the -# embedding/encoder dimension, and D is the decoder dimension. - -# Returns: -# Tensor: Description of returned object. - -# """ -# t, b = enc_out.shape[:2] -# # repeat decoder hidden state src_len times -# hidden_state = hidden_state.unsqueeze(1).repeat(1, t, 1) - -# encoder_outputs = rearrange(encoder_outputs, "t b e2 -> b t e2") - -# # Calculate the energy between the decoders previous hidden state and the -# # encoders hidden states. -# energy = torch.tanh( -# self.attn(torch.cat((hidden_state, encoder_outputs), dim=2)) -# ) - -# attention = self.value(energy).squeeze(2) - -# # Apply softmax on the attention to squeeze it between 0 and 1. -# attention = F.softmax(attention, dim=1) - -# return attention - - -# class Decoder(nn.Module): -# def __init__( -# self, -# embedding_dim: int, -# encoder_dim: int, -# decoder_dim: int, -# output_dim: int, -# dropout_rate: float = 0.1, -# ) -> None: -# super().__init__() -# self.output_dim = output_dim -# self.embedding = nn.Embedding(output_dim, embedding_dim) -# self.attention = Attention(encoder_dim, decoder_dim) -# self.rnn = nn.GRU( -# input_size=2 * encoder_dim + embedding_dim, hidden_size=decoder_dim -# ) - -# self.head = nn.Linear( -# in_features=2 * encoder_dim + embedding_dim + decoder_dim, -# out_features=output_dim, -# ) -# self.dropout = nn.Dropout(p=dropout_rate) - -# def forward( -# self, trg: Tensor, hidden_state: Tensor, encoder_outputs: Tensor -# ) -> Tensor: -# # input = [batch size] -# # hidden = [batch size, dec hid dim] -# # encoder_outputs = [src len, batch size, enc hid dim * 2] -# trg = trg.unsqueeze(0) -# trg_embedded = self.dropout(self.embedding(trg)) - -# a = self.attention(hidden_state, encoder_outputs) - -# weighted = torch.bmm(a, encoder_outputs) - -# # Permutate the tensor. 
-# weighted = rearrange(weighted, "b a e2 -> a b e2") - -# rnn_input = torch.cat((trg_embedded, weighted), dim=2) - -# output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0)) - -# # seq len, n layers and n directions will always be 1 in this decoder, therefore: -# # output = [1, batch size, dec hid dim] -# # hidden = [1, batch size, dec hid dim] -# # this also means that output == hidden -# assert (output == hidden).all() - -# trg_embedded = trg_embedded.squeeze(0) -# output = output.squeeze(0) -# weighted = weighted.squeeze(0) - -# logits = self.fc_out(torch.cat((output, weighted, trg_embedded), dim=1)) - -# # prediction = [batch size, output dim] - -# return logits, hidden.squeeze(0) - - -# class NeuralMachineReader(nn.Module): -# def __init__( -# self, -# embedding_dim: int, -# encoder_dim: int, -# decoder_dim: int, -# output_dim: int, -# backbone: Optional[str] = None, -# backbone_args: Optional[Dict] = None, -# adaptive_pool_dim: Tuple = (None, 1), -# dropout_rate: float = 0.1, -# teacher_forcing_ratio: float = 0.5, -# ) -> None: -# super().__init__() - -# self.backbone = configure_backbone(backbone, backbone_args) -# self.adaptive_pool = nn.AdaptiveAvgPool2d((adaptive_pool_dim)) - -# self.encoder = Encoder(embedding_dim, encoder_dim, decoder_dim, dropout_rate) -# self.decoder = Decoder( -# embedding_dim, encoder_dim, decoder_dim, output_dim, dropout_rate -# ) -# self.teacher_forcing_ratio = teacher_forcing_ratio - -# def extract_image_features(self, x: Tensor) -> Tensor: -# x = self.backbone(x) -# x = rearrange(x, "b c h w -> b w c h") -# x = self.adaptive_pool(x) -# x = x.squeeze(3) - -# def forward(self, x: Tensor, trg: Tensor) -> Tensor: -# # x = [batch size, height, width] -# # trg = [trg len, batch size] -# z = self.extract_image_features(x) diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py index 6405192..e397224 100644 --- a/src/text_recognizer/networks/residual_network.py +++ b/src/text_recognizer/networks/residual_network.py @@ -7,7 +7,6 @@ import torch from torch import nn from torch import Tensor -from text_recognizer.networks.stn import SpatialTransformerNetwork from text_recognizer.networks.util import activation_function @@ -209,12 +208,10 @@ class ResidualNetworkEncoder(nn.Module): activation: str = "relu", block: Type[nn.Module] = BasicBlock, levels: int = 1, - stn: bool = False, *args, **kwargs ) -> None: super().__init__() - self.stn = SpatialTransformerNetwork() if stn else None self.block_sizes = ( block_sizes if isinstance(block_sizes, list) else [block_sizes] * levels ) @@ -231,7 +228,7 @@ class ResidualNetworkEncoder(nn.Module): ), nn.BatchNorm2d(self.block_sizes[0]), activation_function(self.activation), - nn.MaxPool2d(kernel_size=2, stride=2, padding=1), + # nn.MaxPool2d(kernel_size=2, stride=2, padding=1), ) self.blocks = self._configure_blocks(block) @@ -275,8 +272,6 @@ class ResidualNetworkEncoder(nn.Module): # If batch dimenstion is missing, it needs to be added. 
if len(x.shape) == 3: x = x.unsqueeze(0) - if self.stn is not None: - x = self.stn(x) x = self.gate(x) x = self.blocks(x) return x diff --git a/src/text_recognizer/networks/unet.py b/src/text_recognizer/networks/unet.py index 51f242a..510910f 100644 --- a/src/text_recognizer/networks/unet.py +++ b/src/text_recognizer/networks/unet.py @@ -8,64 +8,118 @@ from torch import Tensor from text_recognizer.networks.util import activation_function -class ConvBlock(nn.Module): - """Basic UNet convolutional block.""" +class _ConvBlock(nn.Module): + """Modified UNet convolutional block with dilation.""" - def __init__(self, channels: List[int], activation: str) -> None: + def __init__( + self, + channels: List[int], + activation: str, + num_groups: int, + dropout_rate: float = 0.1, + kernel_size: int = 3, + dilation: int = 1, + padding: int = 0, + ) -> None: super().__init__() self.channels = channels + self.dropout_rate = dropout_rate + self.kernel_size = kernel_size + self.dilation = dilation + self.padding = padding + self.num_groups = num_groups self.activation = activation_function(activation) self.block = self._configure_block() + self.residual_conv = nn.Sequential( + nn.Conv2d( + self.channels[0], self.channels[-1], kernel_size=3, stride=1, padding=1 + ), + self.activation, + ) def _configure_block(self) -> nn.Sequential: block = [] for i in range(len(self.channels) - 1): block += [ + nn.Dropout(p=self.dropout_rate), + nn.GroupNorm(self.num_groups, self.channels[i]), + self.activation, nn.Conv2d( - self.channels[i], self.channels[i + 1], kernel_size=3, padding=1 + self.channels[i], + self.channels[i + 1], + kernel_size=self.kernel_size, + padding=self.padding, + stride=1, + dilation=self.dilation, ), - nn.BatchNorm2d(self.channels[i + 1]), - self.activation, ] return nn.Sequential(*block) def forward(self, x: Tensor) -> Tensor: """Apply the convolutional block.""" - return self.block(x) + residual = self.residual_conv(x) + return self.block(x) + residual -class DownSamplingBlock(nn.Module): +class _DownSamplingBlock(nn.Module): """Basic down sampling block.""" def __init__( self, channels: List[int], activation: str, + num_groups: int, pooling_kernel: Union[int, bool] = 2, + dropout_rate: float = 0.1, + kernel_size: int = 3, + dilation: int = 1, + padding: int = 0, ) -> None: super().__init__() - self.conv_block = ConvBlock(channels, activation) + self.conv_block = _ConvBlock( + channels, + activation, + num_groups, + dropout_rate, + kernel_size, + dilation, + padding, + ) self.down_sampling = nn.MaxPool2d(pooling_kernel) if pooling_kernel else None def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: """Return the convolutional block output and a down sampled tensor.""" x = self.conv_block(x) - if self.down_sampling is not None: - x_down = self.down_sampling(x) - else: - x_down = None + x_down = self.down_sampling(x) if self.down_sampling is not None else x + return x_down, x -class UpSamplingBlock(nn.Module): +class _UpSamplingBlock(nn.Module): """The upsampling block of the UNet.""" def __init__( - self, channels: List[int], activation: str, scale_factor: int = 2 + self, + channels: List[int], + activation: str, + num_groups: int, + scale_factor: int = 2, + dropout_rate: float = 0.1, + kernel_size: int = 3, + dilation: int = 1, + padding: int = 0, ) -> None: super().__init__() - self.conv_block = ConvBlock(channels, activation) + self.conv_block = _ConvBlock( + channels, + activation, + num_groups, + dropout_rate, + kernel_size, + dilation, + padding, + ) self.up_sampling = nn.Upsample( 
            scale_factor=scale_factor, mode="bilinear", align_corners=True
        )

@@ -87,14 +141,43 @@ class UNet(nn.Module):
         base_channels: int = 64,
         num_classes: int = 3,
         depth: int = 4,
-        out_channels: int = 3,
         activation: str = "relu",
+        num_groups: int = 8,
+        dropout_rate: float = 0.1,
         pooling_kernel: int = 2,
         scale_factor: int = 2,
+        kernel_size: Optional[List[int]] = None,
+        dilation: Optional[List[int]] = None,
+        padding: Optional[List[int]] = None,
     ) -> None:
         super().__init__()
         self.depth = depth
-        channels = [1] + [base_channels * 2 ** i for i in range(depth)]
+        self.num_groups = num_groups
+
+        if kernel_size is not None and dilation is not None and padding is not None:
+            if (
+                len(kernel_size) != depth
+                or len(dilation) != depth
+                or len(padding) != depth
+            ):
+                raise RuntimeError(
+                    "Length of convolutional parameters does not match the depth."
+                )
+            self.kernel_size = kernel_size
+            self.padding = padding
+            self.dilation = dilation
+
+        else:
+            self.kernel_size = [3] * depth
+            self.padding = [1] * depth
+            self.dilation = [1] * depth
+
+        self.dropout_rate = dropout_rate
+        self.conv = nn.Conv2d(
+            in_channels, base_channels, kernel_size=3, stride=1, padding=1
+        )
+
+        channels = [base_channels] + [base_channels * 2 ** i for i in range(depth)]
         self.encoder_blocks = self._configure_down_sampling_blocks(
             channels, activation, pooling_kernel
         )
@@ -110,49 +193,63 @@ class UNet(nn.Module):
         blocks = nn.ModuleList([])
         for i in range(len(channels) - 1):
             pooling_kernel = pooling_kernel if i < self.depth - 1 else False
+            dropout_rate = self.dropout_rate if i < 0 else 0
             blocks += [
-                DownSamplingBlock(
+                _DownSamplingBlock(
                     [channels[i], channels[i + 1], channels[i + 1]],
                     activation,
+                    self.num_groups,
                     pooling_kernel,
+                    dropout_rate,
+                    self.kernel_size[i],
+                    self.dilation[i],
+                    self.padding[i],
                 )
             ]
 
         return blocks
 
     def _configure_up_sampling_blocks(
-        self,
-        channels: List[int],
-        activation: str,
-        scale_factor: int,
+        self, channels: List[int], activation: str, scale_factor: int,
     ) -> nn.ModuleList:
         channels.reverse()
+        self.kernel_size.reverse()
+        self.dilation.reverse()
+        self.padding.reverse()
         return nn.ModuleList(
             [
-                UpSamplingBlock(
+                _UpSamplingBlock(
                     [channels[i] + channels[i + 1], channels[i + 1], channels[i + 1]],
                     activation,
+                    self.num_groups,
                     scale_factor,
+                    self.dropout_rate,
+                    self.kernel_size[i],
+                    self.dilation[i],
+                    self.padding[i],
                 )
                 for i in range(len(channels) - 2)
             ]
         )
 
-    def encode(self, x: Tensor) -> Tuple[Tensor, List[Tensor]]:
+    def _encode(self, x: Tensor) -> List[Tensor]:
         x_skips = []
         for block in self.encoder_blocks:
             x, x_skip = block(x)
-            if x_skip is not None:
-                x_skips.append(x_skip)
-        return x, x_skips
+            x_skips.append(x_skip)
+        return x_skips
 
-    def decode(self, x: Tensor, x_skips: List[Tensor]) -> Tensor:
+    def _decode(self, x_skips: List[Tensor]) -> Tensor:
+        x = x_skips[-1]
         for i, block in enumerate(self.decoder_blocks):
             x = block(x, x_skips[-(i + 2)])
         return x
 
     def forward(self, x: Tensor) -> Tensor:
-        x, x_skips = self.encode(x)
-        x = self.decode(x, x_skips)
+        """Forward pass with the UNet model."""
+        if len(x.shape) < 4:
+            x = x[(None,) * (4 - len(x.shape))]
+        x = self.conv(x)
+        x_skips = self._encode(x)
+        x = self._decode(x_skips)
         return self.head(x)

diff --git a/src/text_recognizer/paragraph_text_recognizer.py b/src/text_recognizer/paragraph_text_recognizer.py
new file mode 100644
index 0000000..aa39662
--- /dev/null
+++ b/src/text_recognizer/paragraph_text_recognizer.py
@@ -0,0 +1,153 @@
+"""Full model.
+
+Takes an image and returns the text in the image, by first segmenting the image with a LineDetector, then extracting
+each crop of the image corresponding to line regions, and feeding them to a LinePredictor model that outputs the text
+in each region.
+"""
+from typing import Dict, List, Tuple, Union
+
+import cv2
+import numpy as np
+import torch
+
+from text_recognizer.models import SegmentationModel, TransformerModel
+from text_recognizer.util import read_image
+
+
+class ParagraphTextRecognizor:
+    """Given an image of a paragraph of handwritten text, recognizes the text in it."""
+
+    def __init__(self, line_predictor_args: Dict, line_detector_args: Dict) -> None:
+        self._line_predictor = TransformerModel(**line_predictor_args)
+        self._line_detector = SegmentationModel(**line_detector_args)
+        self._line_detector.eval()
+        self._line_predictor.eval()
+
+    def predict(self, image_or_filename: Union[str, np.ndarray]) -> Tuple:
+        """Takes an image and returns all text within it."""
+        image = (
+            read_image(image_or_filename)
+            if isinstance(image_or_filename, str)
+            else image_or_filename
+        )
+
+        line_region_crops = self._get_line_region_crops(image)
+        processed_line_region_crops = [
+            self._process_image_for_line_predictor(image=crop)
+            for crop in line_region_crops
+        ]
+        line_region_strings = [
+            self._line_predictor.predict_on_image(crop)[0]
+            for crop in processed_line_region_crops
+        ]
+
+        return " ".join(line_region_strings), line_region_crops
+
+    def _get_line_region_crops(
+        self, image: np.ndarray, min_crop_len_factor: float = 0.02
+    ) -> List[np.ndarray]:
+        """Returns all the crops of text lines in a square image."""
+        processed_image, scale_down_factor = self._process_image_for_line_detector(
+            image
+        )
+        line_segmentation = self._line_detector.predict_on_image(processed_image)
+        bounding_boxes = _find_line_bounding_boxes(line_segmentation)
+
+        bounding_boxes = (bounding_boxes * scale_down_factor).astype(int)
+
+        min_crop_len = int(min_crop_len_factor * min(image.shape[0], image.shape[1]))
+        line_region_crops = [
+            image[y : y + h, x : x + w]
+            for x, y, w, h in bounding_boxes
+            if w >= min_crop_len and h >= min_crop_len
+        ]
+        return line_region_crops
+
+    def _process_image_for_line_detector(
+        self, image: np.ndarray
+    ) -> Tuple[np.ndarray, float]:
+        """Convert uint8 image to float image with black background with shape self._line_detector.image_shape."""
+        resized_image, scale_down_factor = _resize_image_for_line_detector(
+            image=image, max_shape=self._line_detector.image_shape
+        )
+        resized_image = (1.0 - resized_image / 255).astype("float32")
+        return resized_image, scale_down_factor
+
+    def _process_image_for_line_predictor(self, image: np.ndarray) -> np.ndarray:
+        """Preprocessing of image before feeding it to the LinePrediction model.
+
+        Convert uint8 image to float image with black background with shape
+        self._line_predictor.image_shape while maintaining the image aspect ratio.
+
+        Args:
+            image (np.ndarray): Crop of text line.
+
+        Returns:
+            np.ndarray: Processed crop for feeding line predictor.
+ """ + expected_shape = self._line_detector.image_shape + scale_factor = (np.array(expected_shape) / np.array(image.shape)).min() + scaled_image = cv2.resize( + image, + dsize=None, + fx=scale_factor, + fy=scale_factor, + interpolation=cv2.INTER_AREA, + ) + + pad_with = ( + (0, expected_shape[0] - scaled_image.shape[0]), + (0, expected_shape[1] - scaled_image.shape[1]), + ) + + padded_image = np.pad( + scaled_image, pad_with=pad_with, mode="constant", constant_values=255 + ) + return 1 - padded_image / 255 + + +def _find_line_bounding_boxes(line_segmentation: np.ndarray) -> np.ndarray: + """Given a line segmentation, find bounding boxes for connected-component regions corresponding to non-0 labels.""" + + def _find_line_bounding_boxes_in_channel( + line_segmentation_channel: np.ndarray, + ) -> np.ndarray: + line_segmentation_image = cv2.dilate( + line_segmentation_channel, kernel=np.ones((3, 3)), iterations=1 + ) + line_activation_image = (line_segmentation_image * 255).astype("uint8") + line_activation_image = cv2.threshold( + line_activation_image, 0.5, 1, cv2.THRESH_BINARY | cv2.THRESH_OTSU + )[1] + + bounding_cnts, _ = cv2.findContours( + line_segmentation_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE + ) + return np.array([cv2.boundingRect(cnt) for cnt in bounding_cnts]) + + bounding_boxes = np.concatenate( + [ + _find_line_bounding_boxes_in_channel(line_segmentation[:, :, i]) + for i in [1, 2] + ], + axis=0, + ) + + return bounding_boxes[np.argsort(bounding_boxes[:, 1])] + + +def _resize_image_for_line_detector( + image: np.ndarray, max_shape: Tuple[int, int] +) -> Tuple[np.ndarray, float]: + """Resize the image to less than the max_shape while maintaining the aspect ratio.""" + scale_down_factor = max(np.ndarray(image.shape) / np.ndarray(max_shape)) + if scale_down_factor == 1: + return image.copy(), scale_down_factor + resize_image = cv2.resize( + image, + dsize=None, + fx=1 / scale_down_factor, + fy=1 / scale_down_factor, + interpolation=cv2.INTER_AREA, + ) + return resize_image, scale_down_factor diff --git a/src/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg b/src/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg new file mode 100644 index 0000000..d9753b6 Binary files /dev/null and b/src/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg differ diff --git a/src/text_recognizer/tests/test_paragraph_text_recognizer.py b/src/text_recognizer/tests/test_paragraph_text_recognizer.py new file mode 100644 index 0000000..3e280b9 --- /dev/null +++ b/src/text_recognizer/tests/test_paragraph_text_recognizer.py @@ -0,0 +1,37 @@ +"""Test for ParagraphTextRecognizer class.""" +import os +from pathlib import Path +import unittest + +from text_recognizer.paragraph_text_recognizer import ParagraphTextRecognizor +import text_recognizer.util as util + + +SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "support" / "iam_paragraph" + +# Prevent using GPU. 
+os.environ["CUDA_VISIBLE_DEVICES"] = "" + + +class TestParagraphTextRecognizor(unittest.TestCase): + """Test that it can take non-square images of max dimension larger than 256px.""" + + def test_filename(self) -> None: + """Test model on support image.""" + line_predictor_args = { + "dataset": "EmnistLineDataset", + "network_fn": "CNNTransformer", + } + line_detector_args = {"dataset": "EmnistLineDataset", "network_fn": "UNet"} + model = ParagraphTextRecognizor( + line_predictor_args=line_predictor_args, + line_detector_args=line_detector_args, + ) + num_text_lines_by_name = {"a01-000u-cropped": 7} + for filename in (SUPPORT_DIRNAME).glob("*.jpg"): + full_image = util.read_image(str(filename), grayscale=True) + predicted_text, line_region_crops = model.predict(full_image) + print(predicted_text) + self.assertTrue( + len(line_region_crops), num_text_lines_by_name[filename.stem] + ) diff --git a/src/text_recognizer/util.py b/src/text_recognizer/util.py index 6c07c60..b431e22 100644 --- a/src/text_recognizer/util.py +++ b/src/text_recognizer/util.py @@ -21,20 +21,21 @@ def read_image(image_uri: Union[Path, str], grayscale: bool = False) -> np.ndarr return cv2.imdecode(image_array, imread_flag) else: raise ValueError( - "Url does not start with http, therfore not safe to open..." + "Url does not start with http, therefore not safe to open..." ) from None imread_flag = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR local_file = os.path.exists(image_uri) - try: - image = None - if local_file: - image = read_image_from_filename(image_uri, imread_flag) - else: - image = read_image_from_url(image_uri, imread_flag) - assert image is not None - except Exception as e: - raise ValueError(f"Could not load image at {image_uri}: {e}") + image = None + + if local_file: + image = read_image_from_filename(image_uri, imread_flag) + else: + image = read_image_from_url(image_uri, imread_flag) + + if image is None: + raise ValueError(f"Could not load image at {image_uri}") + return image diff --git a/src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_FCN_weights.pt b/src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_FCN_weights.pt new file mode 100644 index 0000000..d9ca01d Binary files /dev/null and b/src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_FCN_weights.pt differ diff --git a/src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_UNet_weights.pt b/src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_UNet_weights.pt new file mode 100644 index 0000000..0af0e57 Binary files /dev/null and b/src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_UNet_weights.pt differ -- cgit v1.2.3-70-g09d2