author     aktersnurra <gustaf.rydholm@gmail.com>  2020-12-07 22:54:04 +0100
committer  aktersnurra <gustaf.rydholm@gmail.com>  2020-12-07 22:54:04 +0100
commit     25b5d6983d51e0e791b96a76beb7e49f392cd9a8 (patch)
tree       526ba739714b3d040f7810c1a6be3ff0ba37fdb1 /src/text_recognizer
parent     5529e0fc9ca39e81fe0f08a54f257d32f0afe120 (diff)
Segmentation working!
Diffstat (limited to 'src/text_recognizer')
-rw-r--r--  src/text_recognizer/datasets/emnist_lines_dataset.py | 11
-rw-r--r--  src/text_recognizer/datasets/iam_paragraphs_dataset.py | 8
-rw-r--r--  src/text_recognizer/datasets/transforms.py | 18
-rw-r--r--  src/text_recognizer/models/__init__.py | 2
-rw-r--r--  src/text_recognizer/models/base.py | 11
-rw-r--r--  src/text_recognizer/models/segmentation_model.py | 75
-rw-r--r--  src/text_recognizer/models/transformer_model.py | 4
-rw-r--r--  src/text_recognizer/networks/__init__.py | 4
-rw-r--r--  src/text_recognizer/networks/beam.py | 83
-rw-r--r--  src/text_recognizer/networks/cnn_transformer.py | 19
-rw-r--r--  src/text_recognizer/networks/fcn.py | 99
-rw-r--r--  src/text_recognizer/networks/neural_machine_reader.py | 201
-rw-r--r--  src/text_recognizer/networks/residual_network.py | 7
-rw-r--r--  src/text_recognizer/networks/unet.py | 159
-rw-r--r--  src/text_recognizer/paragraph_text_recognizer.py | 153
-rw-r--r--  src/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg | bin 0 -> 14890 bytes
-rw-r--r--  src/text_recognizer/tests/test_paragraph_text_recognizer.py | 37
-rw-r--r--  src/text_recognizer/util.py | 21
-rw-r--r--  src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_FCN_weights.pt | bin 0 -> 8588813 bytes
-rw-r--r--  src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_UNet_weights.pt | bin 0 -> 92335101 bytes
20 files changed, 551 insertions, 361 deletions
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index 6871492..eddf341 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -10,6 +10,7 @@ from loguru import logger
import numpy as np
import torch
from torch import Tensor
+import torch.nn.functional as F
from torchvision.transforms import ToTensor
from text_recognizer.datasets.dataset import Dataset
@@ -23,6 +24,8 @@ from text_recognizer.datasets.util import (
DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines"
+MAX_WIDTH = 952
+
class EmnistLinesDataset(Dataset):
"""Synthetic dataset of lines from the Brown corpus with Emnist characters."""
@@ -254,6 +257,14 @@ def construct_image_from_string(
for image in sampled_images:
concatenated_image[:, x : (x + width)] += image
x += next_overlap_width
+
+ if concatenated_image.shape[-1] > MAX_WIDTH:
+ concatenated_image = Tensor(concatenated_image).unsqueeze(0)
+ concatenated_image = F.interpolate(
+ concatenated_image, size=MAX_WIDTH, mode="nearest"
+ )
+ concatenated_image = concatenated_image.squeeze(0).numpy()
+
return np.minimum(255, concatenated_image)
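
The new MAX_WIDTH clamp resizes any synthetic line image wider than 952 px along its width axis before clipping to [0, 255]. A minimal standalone sketch of the same operation, assuming a 2-D (height, width) float image as produced by construct_image_from_string:

import numpy as np
import torch
import torch.nn.functional as F

MAX_WIDTH = 952

def clamp_width(image: np.ndarray) -> np.ndarray:
    """Resize a (height, width) line image to at most MAX_WIDTH columns."""
    if image.shape[-1] <= MAX_WIDTH:
        return image
    # (H, W) -> (1, H, W); F.interpolate treats this as (batch, channels, length)
    # and resizes only the last (width) dimension with nearest-neighbour sampling.
    tensor = torch.Tensor(image).unsqueeze(0)
    tensor = F.interpolate(tensor, size=MAX_WIDTH, mode="nearest")
    return tensor.squeeze(0).numpy()

print(clamp_width(np.zeros((28, 1200), dtype="float32")).shape)  # (28, 952)
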
diff --git a/src/text_recognizer/datasets/iam_paragraphs_dataset.py b/src/text_recognizer/datasets/iam_paragraphs_dataset.py
index c1e8fe2..8ba5142 100644
--- a/src/text_recognizer/datasets/iam_paragraphs_dataset.py
+++ b/src/text_recognizer/datasets/iam_paragraphs_dataset.py
@@ -1,4 +1,5 @@
"""IamParagraphsDataset class and functions for data processing."""
+import random
from typing import Callable, Dict, List, Optional, Tuple, Union
import click
@@ -71,13 +72,18 @@ class IamParagraphsDataset(Dataset):
data = self.data[index]
targets = self.targets[index]
+ seed = np.random.randint(SEED)
+ random.seed(seed) # apply this seed to the image transforms
+ torch.manual_seed(seed) # needed for torchvision 0.7
if self.transform:
data = self.transform(data)
+ random.seed(seed) # apply the same seed to the target transforms
+ torch.manual_seed(seed) # needed for torchvision 0.7
if self.target_transform:
targets = self.target_transform(targets)
- return data, targets
+ return data, targets.long()
@property
def ids(self) -> Tensor:
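
The shared seeding above ensures the image and its target mask receive identical random augmentations. A sketch of the pattern in isolation, assuming `transform` is any torchvision transform whose randomness is driven by the `random`/`torch` RNGs (the SEED bound is illustrative):

import random
import numpy as np
import torch

SEED = 4711  # illustrative upper bound for the per-sample seed

def paired_transform(data, targets, transform):
    """Apply the same random transform parameters to data and targets."""
    seed = np.random.randint(SEED)
    random.seed(seed)
    torch.manual_seed(seed)  # torchvision 0.7 samples parameters via these RNGs
    data = transform(data)
    random.seed(seed)        # re-seed so the targets see identical parameters
    torch.manual_seed(seed)
    targets = transform(targets)
    return data, targets
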
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py
index 1ec23dc..016ec80 100644
--- a/src/text_recognizer/datasets/transforms.py
+++ b/src/text_recognizer/datasets/transforms.py
@@ -4,7 +4,7 @@ from PIL import Image
import torch
from torch import Tensor
import torch.nn.functional as F
-from torchvision.transforms import Compose, RandomAffine, ToTensor
+from torchvision.transforms import Compose, RandomAffine, RandomHorizontalFlip, ToTensor
from text_recognizer.datasets.util import EmnistMapper
@@ -77,3 +77,19 @@ class ApplyContrast:
"""Apply mask binary mask to input tensor."""
mask = x > np.random.RandomState().uniform(low=self.low, high=self.high)
return x * mask
+
+
+class Unsqueeze:
+ """Add a dimension to the tensor."""
+
+ def __call__(self, x: Tensor) -> Tensor:
+ """Adds dim."""
+ return x.unsqueeze(0)
+
+
+class Squeeze:
+ """Removes the first dimension of a tensor."""
+
+ def __call__(self, x: Tensor) -> Tensor:
+ """Removes first dim."""
+ return x.squeeze(0)
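
Unsqueeze and Squeeze are simple shape adapters for Compose pipelines; a shape-only illustration (the pipeline in which they are used is not shown here):

import torch
from text_recognizer.datasets.transforms import Squeeze, Unsqueeze

x = torch.rand(1, 28, 952)
print(Unsqueeze()(x).shape)  # torch.Size([1, 1, 28, 952])
print(Squeeze()(x).shape)    # torch.Size([28, 952])
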
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py
index bf89404..a645cec 100644
--- a/src/text_recognizer/models/__init__.py
+++ b/src/text_recognizer/models/__init__.py
@@ -2,11 +2,13 @@
from .base import Model
from .character_model import CharacterModel
from .crnn_model import CRNNModel
+from .segmentation_model import SegmentationModel
from .transformer_model import TransformerModel
__all__ = [
"CharacterModel",
"CRNNModel",
"Model",
+ "SegmentationModel",
"TransformerModel",
]
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index d394b4c..f2cd4b8 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -159,7 +159,7 @@ class Model(ABC):
self.test_dataset = self.dataset(train=False, **self.dataset_args["args"])
self.test_dataset.load_or_generate_data()
- # Set the flag to true to disable ability to load data agian.
+ # Set the flag to true to disable ability to load data again.
self.data_prepared = True
def train_dataloader(self) -> DataLoader:
@@ -260,7 +260,7 @@ class Model(ABC):
@property
def mapping(self) -> Dict:
"""Returns the mapping between network output and Emnist character."""
- return self._mapper.mapping
+ return self._mapper.mapping if self._mapper is not None else None
def eval(self) -> None:
"""Sets the network to evaluation mode."""
@@ -341,7 +341,7 @@ class Model(ABC):
if input_shape is not None:
summary(self.network, input_shape, depth=depth, device=device)
elif self._input_shape is not None:
- input_shape = (1,) + tuple(self._input_shape)
+ input_shape = tuple(self._input_shape)
summary(self.network, input_shape, depth=depth, device=device)
else:
logger.warning("Could not print summary as input shape is not set.")
@@ -427,7 +427,7 @@ class Model(ABC):
)
shutil.copyfile(filepath, str(checkpoint_path / "best.pt"))
- def load_weights(self, network_fn: Type[nn.Module]) -> None:
+ def load_weights(self, network_fn: Optional[Type[nn.Module]] = None) -> None:
"""Load the network weights."""
logger.debug("Loading network with pretrained weights.")
filename = glob(self.weights_filename)[0]
@@ -441,7 +441,8 @@ class Model(ABC):
weights = state_dict["model_state"]
# Initializes the network with trained weights.
- self._network = network_fn(**self._network_args)
+ if network_fn is not None:
+ self._network = network_fn(**self._network_args)
self._network.load_state_dict(weights)
if "swa_network" in state_dict:
diff --git a/src/text_recognizer/models/segmentation_model.py b/src/text_recognizer/models/segmentation_model.py
new file mode 100644
index 0000000..613108a
--- /dev/null
+++ b/src/text_recognizer/models/segmentation_model.py
@@ -0,0 +1,75 @@
+"""Segmentation model for detecting and segmenting lines."""
+from typing import Callable, Dict, Optional, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.models.base import Model
+
+
+class SegmentationModel(Model):
+ """Model for segmenting lines in an image."""
+
+ def __init__(
+ self,
+ network_fn: str,
+ dataset: str,
+ network_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ super().__init__(
+ network_fn,
+ dataset,
+ network_args,
+ dataset_args,
+ metrics,
+ criterion,
+ criterion_args,
+ optimizer,
+ optimizer_args,
+ lr_scheduler,
+ lr_scheduler_args,
+ swa_args,
+ device,
+ )
+ self.tensor_transform = ToTensor()
+ self.softmax = nn.Softmax(dim=2)
+
+ @torch.no_grad()
+ def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tensor:
+ """Predict on a single input."""
+ self.eval()
+
+ if image.dtype == np.uint8:
+ # Converts an image with range [0, 255] to a PyTorch Tensor with range [0, 1].
+ image = self.tensor_transform(image)
+
+ # Rescale image between 0 and 1.
+ if image.dtype is torch.uint8 or image.dtype is torch.int64:
+ # If the image is an unscaled tensor.
+ image = image.type("torch.FloatTensor") / 255
+
+ if not torch.is_tensor(image):
+ image = Tensor(image)
+
+ # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device)
+
+ logits = self.forward(image)
+
+ segmentation_mask = torch.argmax(logits, dim=1)
+
+ return segmentation_mask
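
A hypothetical way to call the new model; the constructor arguments below are assumptions mirroring the test and the weights files added in this commit, not the actual training configuration:

import numpy as np
from text_recognizer.models import SegmentationModel

model = SegmentationModel(network_fn="UNet", dataset="IamParagraphsDataset")
model.load_weights()  # network_fn is now optional when the network is already built

page = np.zeros((256, 256), dtype=np.uint8)  # grayscale page, uint8 in [0, 255]
mask = model.predict_on_image(page)          # (1, H, W) tensor of class indices per pixel
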
diff --git a/src/text_recognizer/models/transformer_model.py b/src/text_recognizer/models/transformer_model.py
index 968a047..a912122 100644
--- a/src/text_recognizer/models/transformer_model.py
+++ b/src/text_recognizer/models/transformer_model.py
@@ -18,8 +18,8 @@ class TransformerModel(Model):
def __init__(
self,
- network_fn: Type[nn.Module],
- dataset: Type[Dataset],
+ network_fn: str,
+ dataset: str,
network_args: Optional[Dict] = None,
dataset_args: Optional[Dict] = None,
metrics: Optional[Dict] = None,
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
index 1635039..f958672 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/src/text_recognizer/networks/__init__.py
@@ -3,11 +3,13 @@ from .cnn_transformer import CNNTransformer
from .crnn import ConvolutionalRecurrentNetwork
from .ctc import greedy_decoder
from .densenet import DenseNet
+from .fcn import FCN
from .lenet import LeNet
from .metrics import accuracy, accuracy_ignore_pad, cer, wer
from .mlp import MLP
from .residual_network import ResidualNetwork, ResidualNetworkEncoder
from .transformer import Transformer
+from .unet import UNet
from .util import sliding_window
from .wide_resnet import WideResidualNetwork
@@ -18,12 +20,14 @@ __all__ = [
"CNNTransformer",
"ConvolutionalRecurrentNetwork",
"DenseNet",
+ "FCN",
"greedy_decoder",
"MLP",
"LeNet",
"ResidualNetwork",
"ResidualNetworkEncoder",
"sliding_window",
+ "UNet",
"Transformer",
"wer",
"WideResidualNetwork",
diff --git a/src/text_recognizer/networks/beam.py b/src/text_recognizer/networks/beam.py
new file mode 100644
index 0000000..dccccdb
--- /dev/null
+++ b/src/text_recognizer/networks/beam.py
@@ -0,0 +1,83 @@
+"""Implementation of beam search decoder for a sequence to sequence network.
+
+Stolen from: https://github.com/budzianowski/PyTorch-Beam-Search-Decoding/blob/master/decode_beam.py
+
+"""
+# from typing import List
+# from Queue import PriorityQueue
+
+# from loguru import logger
+# import torch
+# from torch import nn
+# from torch import Tensor
+# import torch.nn.functional as F
+
+
+# class Node:
+# def __init__(
+# self, parent: Node, target_index: int, log_prob: Tensor, length: int
+# ) -> None:
+# self.parent = parent
+# self.target_index = target_index
+# self.log_prob = log_prob
+# self.length = length
+# self.reward = 0.0
+
+# def eval(self, alpha: float = 1.0) -> Tensor:
+# return self.log_prob / (self.length - 1 + 1e-6) + alpha * self.reward
+
+
+# @torch.no_grad()
+# def beam_decoder(
+# network, mapper, device, memory: Tensor = None, max_len: int = 97,
+# ) -> Tensor:
+# beam_width = 10
+# topk = 1 # How many sentences to generate.
+
+# trg_indices = [mapper(mapper.init_token)]
+
+# end_nodes = []
+
+# node = Node(None, trg_indices, 0, 1)
+# nodes = PriorityQueue()
+
+# nodes.put((node.eval(), node))
+# q_size = 1
+
+# # Beam search
+# for _ in range(max_len):
+# if q_size > 2000:
+# logger.warning("Could not decode input")
+# break
+
+# # Fetch the best node.
+# score, n = nodes.get()
+# decoder_input = n.target_index
+
+# if n.target_index == mapper(mapper.eos_token) and n.parent is not None:
+# end_nodes.append((score, n))
+
+# # If we reached the maximum number of sentences required.
+# if len(end_nodes) >= 1:
+# break
+# else:
+# continue
+
+# # Forward pass with transformer.
+# trg = torch.tensor(trg_indices, device=device)[None, :].long()
+# trg = network.target_embedding(trg)
+# logits = network.decoder(trg=trg, memory=memory, trg_mask=None)
+# log_prob = F.log_softmax(logits, dim=2)
+
+# log_prob, indices = torch.topk(log_prob, beam_width)
+
+# for new_k in range(beam_width):
+# # TODO: continue from here
+# token_index = indices[0][new_k].view(1, -1)
+# log_p = log_prob[0][new_k].item()
+
+# node = Node()
+
+# pass
+
+# pass
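
beam.py is still a commented-out draft. For reference, a minimal self-contained beam search over an arbitrary step function that returns per-token log-probabilities; this is not the project's actual decoder, and the sos/eos handling and length normalisation are assumptions:

from typing import Callable, List, Tuple
import torch

def beam_search(
    step_fn: Callable[[List[int]], torch.Tensor],
    sos_index: int,
    eos_index: int,
    beam_width: int = 10,
    max_len: int = 97,
) -> List[int]:
    """Return the most probable token sequence under step_fn."""
    beams: List[Tuple[float, List[int]]] = [(0.0, [sos_index])]
    for _ in range(max_len):
        candidates: List[Tuple[float, List[int]]] = []
        for score, seq in beams:
            if seq[-1] == eos_index:
                candidates.append((score, seq))
                continue
            log_probs = step_fn(seq)  # (vocab_size,) log-probabilities of the next token
            top_log_probs, top_indices = log_probs.topk(beam_width)
            for log_p, index in zip(top_log_probs.tolist(), top_indices.tolist()):
                candidates.append((score + log_p, seq + [index]))
        # Keep the beam_width best candidates, scored with simple length normalisation.
        beams = sorted(candidates, key=lambda c: c[0] / len(c[1]), reverse=True)[:beam_width]
        if all(seq[-1] == eos_index for _, seq in beams):
            break
    return beams[0][1]
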
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py
index 16c7a41..b2b74b3 100644
--- a/src/text_recognizer/networks/cnn_transformer.py
+++ b/src/text_recognizer/networks/cnn_transformer.py
@@ -88,10 +88,14 @@ class CNNTransformer(nn.Module):
if len(src.shape) < 4:
src = src[(None,) * (4 - len(src.shape))]
src = self.backbone(src)
- src = rearrange(src, "b c h w -> b w c h")
+
if self.adaptive_pool is not None:
+ src = rearrange(src, "b c h w -> b w c h")
src = self.adaptive_pool(src)
- src = src.squeeze(3)
+ src = src.squeeze(3)
+ else:
+ src = rearrange(src, "b c h w -> b (w h) c")
+
src = self.position_encoding(src)
return src
@@ -110,12 +114,17 @@ class CNNTransformer(nn.Module):
trg = self.position_encoding(trg)
return trg
- def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
- """Forward pass with CNN transfomer."""
- h = self.extract_image_features(x)
+ def decode_image_features(self, h: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+ """Takes images features from the backbone and decodes them with the transformer."""
trg_mask = self._create_trg_mask(trg)
trg = self.target_embedding(trg)
out = self.transformer(h, trg, trg_mask=trg_mask)
logits = self.head(out)
return logits
+
+ def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+ """Forward pass with CNN transfomer."""
+ h = self.extract_image_features(x)
+ logits = self.decode_image_features(h, trg)
+ return logits
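
The forward pass is now split so the encoder memory can be computed once and decoded repeatedly, e.g. by a search procedure. A sketch of that reuse under the assumption that `network` is a constructed CNNTransformer and `image` is an input tensor:

import torch

@torch.no_grad()
def greedy_decode(network, image, sos_index: int, eos_index: int, max_len: int = 97):
    """Encode the image once, then decode token by token reusing the same memory."""
    memory = network.extract_image_features(image)          # computed a single time
    trg = torch.tensor([[sos_index]], device=image.device)
    for _ in range(max_len):
        logits = network.decode_image_features(memory, trg)  # (1, T, vocab)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        trg = torch.cat([trg, next_token], dim=1)
        if next_token.item() == eos_index:
            break
    return trg.squeeze(0)
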
diff --git a/src/text_recognizer/networks/fcn.py b/src/text_recognizer/networks/fcn.py
deleted file mode 100644
index f9c4fd4..0000000
--- a/src/text_recognizer/networks/fcn.py
+++ /dev/null
@@ -1,99 +0,0 @@
-"""Fully Convolutional Network (FCN) with dilated kernels for global context."""
-from typing import List, Tuple, Type
-import torch
-from torch import nn
-from torch import Tensor
-
-
-from text_recognizer.networks.util import activation_function
-
-
-class _DilatedBlock(nn.Module):
- def __init__(
- self,
- channels: List[int],
- kernel_sizes: List[int],
- dilations: List[int],
- paddings: List[int],
- activation_fn: Type[nn.Module],
- ) -> None:
- super().__init__()
- self.dilation_conv = nn.Sequential(
- nn.Conv2d(
- in_channels=channels[0],
- out_channels=channels[1],
- kernel_size=kernel_sizes[0],
- stride=1,
- dilation=dilations[0],
- padding=paddings[0],
- ),
- nn.Conv2d(
- in_channels=channels[1],
- out_channels=channels[1] // 2,
- kernel_size=kernel_sizes[1],
- stride=1,
- dilation=dilations[1],
- padding=paddings[1],
- ),
- )
- self.activation_fn = activation_fn
-
- self.conv = nn.Conv2d(
- in_channels=channels[0],
- out_channels=channels[1] // 2,
- kernel_size=1,
- dilation=1,
- stride=1,
- )
-
- def forward(self, x: Tensor) -> Tensor:
- residual = self.conv(x)
- x = self.dilation_conv(x)
- x = torch.cat((x, residual), dim=1)
- return self.activation_fn(x)
-
-
-class FCN(nn.Module):
- def __init__(
- self,
- in_channels: int,
- base_channels: int,
- out_channels: int,
- kernel_size: int,
- dilations: Tuple[int] = (3, 7),
- paddings: Tuple[int] = (9, 21),
- num_blocks: int = 14,
- activation: str = "elu",
- ) -> None:
- super().__init__()
- self.kernel_sizes = [kernel_size] * num_blocks
- self.channels = [in_channels] + [base_channels] * (num_blocks - 1)
- self.out_channels = out_channels
- self.dilations = [dilations[0]] * (num_blocks // 2) + [dilations[1]] * (
- num_blocks // 2
- )
- self.paddings = [paddings[0]] * (num_blocks // 2) + [paddings[1]] * (
- num_blocks // 2
- )
- self.activation_fn = activation_function(activation)
- self.fcn = self._configure_fcn()
-
- def _configure_fcn(self) -> nn.Sequential:
- layers = []
- for i in range(0, len(self.channels), 2):
- layers.append(
- _DilatedBlock(
- self.channels[i : i + 2],
- self.kernel_sizes[i : i + 2],
- self.dilations[i : i + 2],
- self.paddings[i : i + 2],
- self.activation_fn,
- )
- )
- layers.append(
- nn.Conv2d(self.channels[-1], self.out_channels, kernel_size=1, stride=1)
- )
- return nn.Sequential(*layers)
-
- def forward(self, x: Tensor) -> Tensor:
- return self.fcn(x)
diff --git a/src/text_recognizer/networks/neural_machine_reader.py b/src/text_recognizer/networks/neural_machine_reader.py
deleted file mode 100644
index 7f8c49b..0000000
--- a/src/text_recognizer/networks/neural_machine_reader.py
+++ /dev/null
@@ -1,201 +0,0 @@
-"""Sequence to sequence network with RNN cells."""
-# from typing import Dict, Optional, Tuple
-
-# from einops import rearrange
-# from einops.layers.torch import Rearrange
-# import torch
-# from torch import nn
-# from torch import Tensor
-
-# from text_recognizer.networks.util import configure_backbone
-
-
-# class Encoder(nn.Module):
-# def __init__(
-# self,
-# embedding_dim: int,
-# encoder_dim: int,
-# decoder_dim: int,
-# dropout_rate: float = 0.1,
-# ) -> None:
-# super().__init__()
-# self.rnn = nn.GRU(
-# input_size=embedding_dim, hidden_size=encoder_dim, bidirectional=True
-# )
-# self.fc = nn.Sequential(
-# nn.Linear(in_features=2 * encoder_dim, out_features=decoder_dim), nn.Tanh()
-# )
-# self.dropout = nn.Dropout(p=dropout_rate)
-
-# def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
-# """Encodes a sequence of tensors with a bidirectional GRU.
-
-# Args:
-# x (Tensor): A input sequence.
-
-# Shape:
-# - x: :math:`(T, N, E)`.
-# - output[0]: :math:`(T, N, 2 * E)`.
-# - output[1]: :math:`(T, N, D)`.
-
-# where T is the sequence length, N is the batch size, E is the
-# embedding/encoder dimension, and D is the decoder dimension.
-
-# Returns:
-# Tuple[Tensor, Tensor]: The encoder output and the hidden state of the
-# encoder.
-
-# """
-
-# output, hidden = self.rnn(x)
-
-# # Get the hidden state from the forward and backward rnn.
-# hidden_state = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
-
-# # Apply fully connected layer and tanh activation.
-# hidden_state = self.fc(hidden_state)
-
-# return output, hidden_state
-
-
-# class Attention(nn.Module):
-# def __init__(self, encoder_dim: int, decoder_dim: int) -> None:
-# super().__init__()
-# self.atten = nn.Linear(
-# in_features=2 * encoder_dim + decoder_dim, out_features=decoder_dim
-# )
-# self.value = nn.Linear(in_features=decoder_dim, out_features=1, bias=False)
-
-# def forward(self, hidden_state: Tensor, encoder_outputs: Tensor) -> Tensor:
-# """Short summary.
-
-# Args:
-# hidden_state (Tensor): Description of parameter `h`.
-# encoder_outputs (Tensor): Description of parameter `enc_out`.
-
-# Shape:
-# - x: :math:`(T, N, E)`.
-# - output[0]: :math:`(T, N, 2 * E)`.
-# - output[1]: :math:`(T, N, D)`.
-
-# where T is the sequence length, N is the batch size, E is the
-# embedding/encoder dimension, and D is the decoder dimension.
-
-# Returns:
-# Tensor: Description of returned object.
-
-# """
-# t, b = enc_out.shape[:2]
-# # repeat decoder hidden state src_len times
-# hidden_state = hidden_state.unsqueeze(1).repeat(1, t, 1)
-
-# encoder_outputs = rearrange(encoder_outputs, "t b e2 -> b t e2")
-
-# # Calculate the energy between the decoders previous hidden state and the
-# # encoders hidden states.
-# energy = torch.tanh(
-# self.attn(torch.cat((hidden_state, encoder_outputs), dim=2))
-# )
-
-# attention = self.value(energy).squeeze(2)
-
-# # Apply softmax on the attention to squeeze it between 0 and 1.
-# attention = F.softmax(attention, dim=1)
-
-# return attention
-
-
-# class Decoder(nn.Module):
-# def __init__(
-# self,
-# embedding_dim: int,
-# encoder_dim: int,
-# decoder_dim: int,
-# output_dim: int,
-# dropout_rate: float = 0.1,
-# ) -> None:
-# super().__init__()
-# self.output_dim = output_dim
-# self.embedding = nn.Embedding(output_dim, embedding_dim)
-# self.attention = Attention(encoder_dim, decoder_dim)
-# self.rnn = nn.GRU(
-# input_size=2 * encoder_dim + embedding_dim, hidden_size=decoder_dim
-# )
-
-# self.head = nn.Linear(
-# in_features=2 * encoder_dim + embedding_dim + decoder_dim,
-# out_features=output_dim,
-# )
-# self.dropout = nn.Dropout(p=dropout_rate)
-
-# def forward(
-# self, trg: Tensor, hidden_state: Tensor, encoder_outputs: Tensor
-# ) -> Tensor:
-# # input = [batch size]
-# # hidden = [batch size, dec hid dim]
-# # encoder_outputs = [src len, batch size, enc hid dim * 2]
-# trg = trg.unsqueeze(0)
-# trg_embedded = self.dropout(self.embedding(trg))
-
-# a = self.attention(hidden_state, encoder_outputs)
-
-# weighted = torch.bmm(a, encoder_outputs)
-
-# # Permutate the tensor.
-# weighted = rearrange(weighted, "b a e2 -> a b e2")
-
-# rnn_input = torch.cat((trg_embedded, weighted), dim=2)
-
-# output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))
-
-# # seq len, n layers and n directions will always be 1 in this decoder, therefore:
-# # output = [1, batch size, dec hid dim]
-# # hidden = [1, batch size, dec hid dim]
-# # this also means that output == hidden
-# assert (output == hidden).all()
-
-# trg_embedded = trg_embedded.squeeze(0)
-# output = output.squeeze(0)
-# weighted = weighted.squeeze(0)
-
-# logits = self.fc_out(torch.cat((output, weighted, trg_embedded), dim=1))
-
-# # prediction = [batch size, output dim]
-
-# return logits, hidden.squeeze(0)
-
-
-# class NeuralMachineReader(nn.Module):
-# def __init__(
-# self,
-# embedding_dim: int,
-# encoder_dim: int,
-# decoder_dim: int,
-# output_dim: int,
-# backbone: Optional[str] = None,
-# backbone_args: Optional[Dict] = None,
-# adaptive_pool_dim: Tuple = (None, 1),
-# dropout_rate: float = 0.1,
-# teacher_forcing_ratio: float = 0.5,
-# ) -> None:
-# super().__init__()
-
-# self.backbone = configure_backbone(backbone, backbone_args)
-# self.adaptive_pool = nn.AdaptiveAvgPool2d((adaptive_pool_dim))
-
-# self.encoder = Encoder(embedding_dim, encoder_dim, decoder_dim, dropout_rate)
-# self.decoder = Decoder(
-# embedding_dim, encoder_dim, decoder_dim, output_dim, dropout_rate
-# )
-# self.teacher_forcing_ratio = teacher_forcing_ratio
-
-# def extract_image_features(self, x: Tensor) -> Tensor:
-# x = self.backbone(x)
-# x = rearrange(x, "b c h w -> b w c h")
-# x = self.adaptive_pool(x)
-# x = x.squeeze(3)
-
-# def forward(self, x: Tensor, trg: Tensor) -> Tensor:
-# # x = [batch size, height, width]
-# # trg = [trg len, batch size]
-# z = self.extract_image_features(x)
diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py
index 6405192..e397224 100644
--- a/src/text_recognizer/networks/residual_network.py
+++ b/src/text_recognizer/networks/residual_network.py
@@ -7,7 +7,6 @@ import torch
from torch import nn
from torch import Tensor
-from text_recognizer.networks.stn import SpatialTransformerNetwork
from text_recognizer.networks.util import activation_function
@@ -209,12 +208,10 @@ class ResidualNetworkEncoder(nn.Module):
activation: str = "relu",
block: Type[nn.Module] = BasicBlock,
levels: int = 1,
- stn: bool = False,
*args,
**kwargs
) -> None:
super().__init__()
- self.stn = SpatialTransformerNetwork() if stn else None
self.block_sizes = (
block_sizes if isinstance(block_sizes, list) else [block_sizes] * levels
)
@@ -231,7 +228,7 @@ class ResidualNetworkEncoder(nn.Module):
),
nn.BatchNorm2d(self.block_sizes[0]),
activation_function(self.activation),
- nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
+ # nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
)
self.blocks = self._configure_blocks(block)
@@ -275,8 +272,6 @@ class ResidualNetworkEncoder(nn.Module):
# If batch dimension is missing, it needs to be added.
if len(x.shape) == 3:
x = x.unsqueeze(0)
- if self.stn is not None:
- x = self.stn(x)
x = self.gate(x)
x = self.blocks(x)
return x
diff --git a/src/text_recognizer/networks/unet.py b/src/text_recognizer/networks/unet.py
index 51f242a..510910f 100644
--- a/src/text_recognizer/networks/unet.py
+++ b/src/text_recognizer/networks/unet.py
@@ -8,64 +8,118 @@ from torch import Tensor
from text_recognizer.networks.util import activation_function
-class ConvBlock(nn.Module):
- """Basic UNet convolutional block."""
+class _ConvBlock(nn.Module):
+ """Modified UNet convolutional block with dilation."""
- def __init__(self, channels: List[int], activation: str) -> None:
+ def __init__(
+ self,
+ channels: List[int],
+ activation: str,
+ num_groups: int,
+ dropout_rate: float = 0.1,
+ kernel_size: int = 3,
+ dilation: int = 1,
+ padding: int = 0,
+ ) -> None:
super().__init__()
self.channels = channels
+ self.dropout_rate = dropout_rate
+ self.kernel_size = kernel_size
+ self.dilation = dilation
+ self.padding = padding
+ self.num_groups = num_groups
self.activation = activation_function(activation)
self.block = self._configure_block()
+ self.residual_conv = nn.Sequential(
+ nn.Conv2d(
+ self.channels[0], self.channels[-1], kernel_size=3, stride=1, padding=1
+ ),
+ self.activation,
+ )
def _configure_block(self) -> nn.Sequential:
block = []
for i in range(len(self.channels) - 1):
block += [
+ nn.Dropout(p=self.dropout_rate),
+ nn.GroupNorm(self.num_groups, self.channels[i]),
+ self.activation,
nn.Conv2d(
- self.channels[i], self.channels[i + 1], kernel_size=3, padding=1
+ self.channels[i],
+ self.channels[i + 1],
+ kernel_size=self.kernel_size,
+ padding=self.padding,
+ stride=1,
+ dilation=self.dilation,
),
- nn.BatchNorm2d(self.channels[i + 1]),
- self.activation,
]
return nn.Sequential(*block)
def forward(self, x: Tensor) -> Tensor:
"""Apply the convolutional block."""
- return self.block(x)
+ residual = self.residual_conv(x)
+ return self.block(x) + residual
-class DownSamplingBlock(nn.Module):
+class _DownSamplingBlock(nn.Module):
"""Basic down sampling block."""
def __init__(
self,
channels: List[int],
activation: str,
+ num_groups: int,
pooling_kernel: Union[int, bool] = 2,
+ dropout_rate: float = 0.1,
+ kernel_size: int = 3,
+ dilation: int = 1,
+ padding: int = 0,
) -> None:
super().__init__()
- self.conv_block = ConvBlock(channels, activation)
+ self.conv_block = _ConvBlock(
+ channels,
+ activation,
+ num_groups,
+ dropout_rate,
+ kernel_size,
+ dilation,
+ padding,
+ )
self.down_sampling = nn.MaxPool2d(pooling_kernel) if pooling_kernel else None
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
"""Return the convolutional block output and a down sampled tensor."""
x = self.conv_block(x)
- if self.down_sampling is not None:
- x_down = self.down_sampling(x)
- else:
- x_down = None
+ x_down = self.down_sampling(x) if self.down_sampling is not None else x
+
return x_down, x
-class UpSamplingBlock(nn.Module):
+class _UpSamplingBlock(nn.Module):
"""The upsampling block of the UNet."""
def __init__(
- self, channels: List[int], activation: str, scale_factor: int = 2
+ self,
+ channels: List[int],
+ activation: str,
+ num_groups: int,
+ scale_factor: int = 2,
+ dropout_rate: float = 0.1,
+ kernel_size: int = 3,
+ dilation: int = 1,
+ padding: int = 0,
) -> None:
super().__init__()
- self.conv_block = ConvBlock(channels, activation)
+ self.conv_block = _ConvBlock(
+ channels,
+ activation,
+ num_groups,
+ dropout_rate,
+ kernel_size,
+ dilation,
+ padding,
+ )
self.up_sampling = nn.Upsample(
scale_factor=scale_factor, mode="bilinear", align_corners=True
)
@@ -87,14 +141,43 @@ class UNet(nn.Module):
base_channels: int = 64,
num_classes: int = 3,
depth: int = 4,
- out_channels: int = 3,
activation: str = "relu",
+ num_groups: int = 8,
+ dropout_rate: float = 0.1,
pooling_kernel: int = 2,
scale_factor: int = 2,
+ kernel_size: Optional[List[int]] = None,
+ dilation: Optional[List[int]] = None,
+ padding: Optional[List[int]] = None,
) -> None:
super().__init__()
self.depth = depth
- channels = [1] + [base_channels * 2 ** i for i in range(depth)]
+ self.num_groups = num_groups
+
+ if kernel_size is not None and dilation is not None and padding is not None:
+ if (
+ len(kernel_size) != depth
+ or len(dilation) != depth
+ or len(padding) != depth
+ ):
+ raise RuntimeError(
+ "Length of convolutional parameters does not match the depth."
+ )
+ self.kernel_size = kernel_size
+ self.padding = padding
+ self.dilation = dilation
+
+ else:
+ self.kernel_size = [3] * depth
+ self.padding = [1] * depth
+ self.dilation = [1] * depth
+
+ self.dropout_rate = dropout_rate
+ self.conv = nn.Conv2d(
+ in_channels, base_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ channels = [base_channels] + [base_channels * 2 ** i for i in range(depth)]
self.encoder_blocks = self._configure_down_sampling_blocks(
channels, activation, pooling_kernel
)
@@ -110,49 +193,63 @@ class UNet(nn.Module):
blocks = nn.ModuleList([])
for i in range(len(channels) - 1):
pooling_kernel = pooling_kernel if i < self.depth - 1 else False
+ dropout_rate = self.dropout_rate if i < 0 else 0
blocks += [
- DownSamplingBlock(
+ _DownSamplingBlock(
[channels[i], channels[i + 1], channels[i + 1]],
activation,
+ self.num_groups,
pooling_kernel,
+ dropout_rate,
+ self.kernel_size[i],
+ self.dilation[i],
+ self.padding[i],
)
]
return blocks
def _configure_up_sampling_blocks(
- self,
- channels: List[int],
- activation: str,
- scale_factor: int,
+ self, channels: List[int], activation: str, scale_factor: int,
) -> nn.ModuleList:
channels.reverse()
+ self.kernel_size.reverse()
+ self.dilation.reverse()
+ self.padding.reverse()
return nn.ModuleList(
[
- UpSamplingBlock(
+ _UpSamplingBlock(
[channels[i] + channels[i + 1], channels[i + 1], channels[i + 1]],
activation,
+ self.num_groups,
scale_factor,
+ self.dropout_rate,
+ self.kernel_size[i],
+ self.dilation[i],
+ self.padding[i],
)
for i in range(len(channels) - 2)
]
)
- def encode(self, x: Tensor) -> Tuple[Tensor, List[Tensor]]:
+ def _encode(self, x: Tensor) -> List[Tensor]:
x_skips = []
for block in self.encoder_blocks:
x, x_skip = block(x)
- if x_skip is not None:
- x_skips.append(x_skip)
- return x, x_skips
+ x_skips.append(x_skip)
+ return x_skips
- def decode(self, x: Tensor, x_skips: List[Tensor]) -> Tensor:
+ def _decode(self, x_skips: List[Tensor]) -> Tensor:
x = x_skips[-1]
for i, block in enumerate(self.decoder_blocks):
x = block(x, x_skips[-(i + 2)])
return x
def forward(self, x: Tensor) -> Tensor:
- x, x_skips = self.encode(x)
- x = self.decode(x, x_skips)
+ """Forward pass with the UNet model."""
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
+ x = self.conv(x)
+ x_skips = self._encode(x)
+ x = self._decode(x_skips)
return self.head(x)
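
An illustrative construction of the reworked UNet; every argument value below is an assumption rather than the configuration behind the released weights, and the per-depth kernel_size/dilation/padding lists must all have length `depth`:

import torch
from text_recognizer.networks import UNet

net = UNet(
    in_channels=1,
    base_channels=32,
    num_classes=3,
    depth=4,
    num_groups=8,
    kernel_size=[3, 3, 3, 3],
    dilation=[1, 1, 2, 2],
    padding=[1, 1, 2, 2],   # padding = dilation keeps the 3x3 convs shape-preserving
)
x = torch.rand(1, 1, 256, 256)
print(net(x).shape)  # expected: torch.Size([1, 3, 256, 256])
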
diff --git a/src/text_recognizer/paragraph_text_recognizer.py b/src/text_recognizer/paragraph_text_recognizer.py
new file mode 100644
index 0000000..aa39662
--- /dev/null
+++ b/src/text_recognizer/paragraph_text_recognizer.py
@@ -0,0 +1,153 @@
+"""Full model.
+
+Takes an image and returns the text in the image, by first segmenting the image with a LineDetector, then extracting
+each crop of the image corresponding to a line region, and feeding the crops to a LinePredictor model that outputs the
+text in each region.
+"""
+from typing import Dict, List, Tuple, Union
+
+import cv2
+import numpy as np
+import torch
+
+from text_recognizer.models import SegmentationModel, TransformerModel
+from text_recognizer.util import read_image
+
+
+class ParagraphTextRecognizor:
+ """Given an image of a single handwritten character, recognizes it."""
+
+ def __init__(self, line_predictor_args: Dict, line_detector_args: Dict) -> None:
+ self._line_predictor = TransformerModel(**line_predictor_args)
+ self._line_detector = SegmentationModel(**line_detector_args)
+ self._line_detector.eval()
+ self._line_predictor.eval()
+
+ def predict(self, image_or_filename: Union[str, np.ndarray]) -> Tuple:
+ """Takes an image and returns all text within it."""
+ image = (
+ read_image(image_or_filename)
+ if isinstance(image_or_filename, str)
+ else image_or_filename
+ )
+
+ line_region_crops = self._get_line_region_crops(image)
+ processed_line_region_crops = [
+ self._process_image_for_line_predictor(image=crop)
+ for crop in line_region_crops
+ ]
+ line_region_strings = [
+ self._line_predictor.predict_on_image(crop)[0]
+ for crop in processed_line_region_crops
+ ]
+
+ return " ".join(line_region_strings), line_region_crops
+
+ def _get_line_region_crops(
+ self, image: np.ndarray, min_crop_len_factor: float = 0.02
+ ) -> List[np.ndarray]:
+ """Returns all the crops of text lines in a square image."""
+ processed_image, scale_down_factor = self._process_image_for_line_detector(
+ image
+ )
+ line_segmentation = self._line_detector.predict_on_image(processed_image)
+ bounding_boxes = _find_line_bounding_boxes(line_segmentation)
+
+ bounding_boxes = (bounding_boxes * scale_down_factor).astype(int)
+
+ min_crop_len = int(min_crop_len_factor * min(image.shape[0], image.shape[1]))
+ line_region_crops = [
+ image[y : y + h, x : x + w]
+ for x, y, w, h in bounding_boxes
+ if w >= min_crop_len and h >= min_crop_len
+ ]
+ return line_region_crops
+
+ def _process_image_for_line_detector(
+ self, image: np.ndarray
+ ) -> Tuple[np.ndarray, float]:
+ """Convert uint8 image to float image with black background with shape self._line_detector.image_shape."""
+ resized_image, scale_down_factor = _resize_image_for_line_detector(
+ image=image, max_shape=self._line_detector.image_shape
+ )
+ resized_image = (1.0 - resized_image / 255).astype("float32")
+ return resized_image, scale_down_factor
+
+ def _process_image_for_line_predictor(self, image: np.ndarray) -> np.ndarray:
+ """Preprocessing of image before feeding it to the LinePrediction model.
+
+ Convert uint8 image to float image with black background with shape
+ self._line_predictor.image_shape while maintaining the image aspect ratio.
+
+ Args:
+ image (np.ndarray): Crop of text line.
+
+ Returns:
+ np.ndarray: Processed crop for feeding line predictor.
+ """
+ expected_shape = self._line_predictor.image_shape
+ scale_factor = (np.array(expected_shape) / np.array(image.shape)).min()
+ scaled_image = cv2.resize(
+ image,
+ dsize=None,
+ fx=scale_factor,
+ fy=scale_factor,
+ interpolation=cv2.INTER_AREA,
+ )
+
+ pad_width = (
+ (0, expected_shape[0] - scaled_image.shape[0]),
+ (0, expected_shape[1] - scaled_image.shape[1]),
+ )
+
+ padded_image = np.pad(
+ scaled_image, pad_width=pad_width, mode="constant", constant_values=255
+ )
+ return 1 - padded_image / 255
+
+
+def _find_line_bounding_boxes(line_segmentation: np.ndarray) -> np.ndarray:
+ """Given a line segmentation, find bounding boxes for connected-component regions corresponding to non-0 labels."""
+
+ def _find_line_bounding_boxes_in_channel(
+ line_segmentation_channel: np.ndarray,
+ ) -> np.ndarray:
+ line_segmentation_image = cv2.dilate(
+ line_segmentation_channel, kernel=np.ones((3, 3)), iterations=1
+ )
+ line_activation_image = (line_segmentation_image * 255).astype("uint8")
+ line_activation_image = cv2.threshold(
+ line_activation_image, 0.5, 1, cv2.THRESH_BINARY | cv2.THRESH_OTSU
+ )[1]
+
+ bounding_cnts, _ = cv2.findContours(
+ line_segmentation_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
+ )
+ return np.array([cv2.boundingRect(cnt) for cnt in bounding_cnts])
+
+ bounding_boxes = np.concatenate(
+ [
+ _find_line_bounding_boxes_in_channel(line_segmentation[:, :, i])
+ for i in [1, 2]
+ ],
+ axis=0,
+ )
+
+ return bounding_boxes[np.argsort(bounding_boxes[:, 1])]
+
+
+def _resize_image_for_line_detector(
+ image: np.ndarray, max_shape: Tuple[int, int]
+) -> Tuple[np.ndarray, float]:
+ """Resize the image to less than the max_shape while maintaining the aspect ratio."""
+ scale_down_factor = max(np.array(image.shape) / np.array(max_shape))
+ if scale_down_factor == 1:
+ return image.copy(), scale_down_factor
+ resize_image = cv2.resize(
+ image,
+ dsize=None,
+ fx=1 / scale_down_factor,
+ fy=1 / scale_down_factor,
+ interpolation=cv2.INTER_AREA,
+ )
+ return resize_image, scale_down_factor
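
A hypothetical end-to-end call, mirroring the test added below; the argument dictionaries are assumptions and the corresponding weights must already be on disk:

from text_recognizer.paragraph_text_recognizer import ParagraphTextRecognizor
from text_recognizer.util import read_image

model = ParagraphTextRecognizor(
    line_predictor_args={"dataset": "EmnistLineDataset", "network_fn": "CNNTransformer"},
    line_detector_args={"dataset": "EmnistLineDataset", "network_fn": "UNet"},
)
image = read_image("a01-000u.jpg", grayscale=True)
text, line_crops = model.predict(image)
print(text)
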
diff --git a/src/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg b/src/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg
new file mode 100644
index 0000000..d9753b6
--- /dev/null
+++ b/src/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg
Binary files differ
diff --git a/src/text_recognizer/tests/test_paragraph_text_recognizer.py b/src/text_recognizer/tests/test_paragraph_text_recognizer.py
new file mode 100644
index 0000000..3e280b9
--- /dev/null
+++ b/src/text_recognizer/tests/test_paragraph_text_recognizer.py
@@ -0,0 +1,37 @@
+"""Test for ParagraphTextRecognizer class."""
+import os
+from pathlib import Path
+import unittest
+
+from text_recognizer.paragraph_text_recognizer import ParagraphTextRecognizor
+import text_recognizer.util as util
+
+
+SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "support" / "iam_paragraphs"
+
+# Prevent using GPU.
+os.environ["CUDA_VISIBLE_DEVICES"] = ""
+
+
+class TestParagraphTextRecognizor(unittest.TestCase):
+ """Test that it can take non-square images of max dimension larger than 256px."""
+
+ def test_filename(self) -> None:
+ """Test model on support image."""
+ line_predictor_args = {
+ "dataset": "EmnistLineDataset",
+ "network_fn": "CNNTransformer",
+ }
+ line_detector_args = {"dataset": "EmnistLineDataset", "network_fn": "UNet"}
+ model = ParagraphTextRecognizor(
+ line_predictor_args=line_predictor_args,
+ line_detector_args=line_detector_args,
+ )
+ num_text_lines_by_name = {"a01-000u-cropped": 7}
+ for filename in (SUPPORT_DIRNAME).glob("*.jpg"):
+ full_image = util.read_image(str(filename), grayscale=True)
+ predicted_text, line_region_crops = model.predict(full_image)
+ print(predicted_text)
+ self.assertEqual(
+ len(line_region_crops), num_text_lines_by_name[filename.stem]
+ )
diff --git a/src/text_recognizer/util.py b/src/text_recognizer/util.py
index 6c07c60..b431e22 100644
--- a/src/text_recognizer/util.py
+++ b/src/text_recognizer/util.py
@@ -21,20 +21,21 @@ def read_image(image_uri: Union[Path, str], grayscale: bool = False) -> np.ndarr
return cv2.imdecode(image_array, imread_flag)
else:
raise ValueError(
- "Url does not start with http, therfore not safe to open..."
+ "Url does not start with http, therefore not safe to open..."
) from None
imread_flag = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
local_file = os.path.exists(image_uri)
- try:
- image = None
- if local_file:
- image = read_image_from_filename(image_uri, imread_flag)
- else:
- image = read_image_from_url(image_uri, imread_flag)
- assert image is not None
- except Exception as e:
- raise ValueError(f"Could not load image at {image_uri}: {e}")
+ image = None
+
+ if local_file:
+ image = read_image_from_filename(image_uri, imread_flag)
+ else:
+ image = read_image_from_url(image_uri, imread_flag)
+
+ if image is None:
+ raise ValueError(f"Could not load image at {image_uri}")
+
return image
diff --git a/src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_FCN_weights.pt b/src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_FCN_weights.pt
new file mode 100644
index 0000000..d9ca01d
--- /dev/null
+++ b/src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_FCN_weights.pt
Binary files differ
diff --git a/src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_UNet_weights.pt b/src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_UNet_weights.pt
new file mode 100644
index 0000000..0af0e57
--- /dev/null
+++ b/src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_UNet_weights.pt
Binary files differ