From 25b5d6983d51e0e791b96a76beb7e49f392cd9a8 Mon Sep 17 00:00:00 2001
From: aktersnurra
Date: Mon, 7 Dec 2020 22:54:04 +0100
Subject: Segmentation working!

---
 src/text_recognizer/networks/__init__.py         |   4 +
 src/text_recognizer/networks/beam.py             |  83 +++++++++
 src/text_recognizer/networks/cnn_transformer.py  |  19 +-
 src/text_recognizer/networks/fcn.py              |  99 ----------
 .../networks/neural_machine_reader.py            | 201 ---------------------
 src/text_recognizer/networks/residual_network.py |   7 +-
 src/text_recognizer/networks/unet.py             | 159 ++++++++++++----
 7 files changed, 230 insertions(+), 342 deletions(-)
 create mode 100644 src/text_recognizer/networks/beam.py
 delete mode 100644 src/text_recognizer/networks/fcn.py
 delete mode 100644 src/text_recognizer/networks/neural_machine_reader.py

(limited to 'src/text_recognizer/networks')

diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
index 1635039..f958672 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/src/text_recognizer/networks/__init__.py
@@ -3,11 +3,13 @@ from .cnn_transformer import CNNTransformer
 from .crnn import ConvolutionalRecurrentNetwork
 from .ctc import greedy_decoder
 from .densenet import DenseNet
+from .fcn import FCN
 from .lenet import LeNet
 from .metrics import accuracy, accuracy_ignore_pad, cer, wer
 from .mlp import MLP
 from .residual_network import ResidualNetwork, ResidualNetworkEncoder
 from .transformer import Transformer
+from .unet import UNet
 from .util import sliding_window
 from .wide_resnet import WideResidualNetwork

@@ -18,12 +20,14 @@ __all__ = [
     "CNNTransformer",
     "ConvolutionalRecurrentNetwork",
     "DenseNet",
+    "FCN",
     "greedy_decoder",
     "MLP",
     "LeNet",
     "ResidualNetwork",
     "ResidualNetworkEncoder",
     "sliding_window",
+    "UNet",
     "Transformer",
     "wer",
     "WideResidualNetwork",
diff --git a/src/text_recognizer/networks/beam.py b/src/text_recognizer/networks/beam.py
new file mode 100644
index 0000000..dccccdb
--- /dev/null
+++ b/src/text_recognizer/networks/beam.py
@@ -0,0 +1,83 @@
+"""Implementation of a beam search decoder for a sequence-to-sequence network.
+
+Stolen from: https://github.com/budzianowski/PyTorch-Beam-Search-Decoding/blob/master/decode_beam.py
+
+"""
+# from typing import List
+# from queue import PriorityQueue
+
+# from loguru import logger
+# import torch
+# from torch import nn
+# from torch import Tensor
+# import torch.nn.functional as F
+
+
+# class Node:
+#     def __init__(
+#         self, parent: Node, target_index: int, log_prob: Tensor, length: int
+#     ) -> None:
+#         self.parent = parent
+#         self.target_index = target_index
+#         self.log_prob = log_prob
+#         self.length = length
+#         self.reward = 0.0
+
+#     def eval(self, alpha: float = 1.0) -> Tensor:
+#         return self.log_prob / (self.length - 1 + 1e-6) + alpha * self.reward
+
+
+# @torch.no_grad()
+# def beam_decoder(
+#     network, mapper, device, memory: Tensor = None, max_len: int = 97,
+# ) -> Tensor:
+#     beam_width = 10
+#     topk = 1  # How many sentences to generate.
+
+#     trg_indices = [mapper(mapper.init_token)]
+
+#     end_nodes = []
+
+#     node = Node(None, trg_indices, 0, 1)
+#     nodes = PriorityQueue()
+
+#     nodes.put((node.eval(), node))
+#     q_size = 1
+
+#     # Beam search
+#     for _ in range(max_len):
+#         if q_size > 2000:
+#             logger.warning("Could not decode input")
+#             break
+
+#         # Fetch the best node.
+#         score, n = nodes.get()
+#         decoder_input = n.target_index
+
+#         if n.target_index == mapper(mapper.eos_token) and n.parent is not None:
+#             end_nodes.append((score, n))
+
+#             # If we have reached the maximum number of sentences required.
+#             if len(end_nodes) >= 1:
+#                 break
+#             else:
+#                 continue
+
+#         # Forward pass with transformer.
+#         trg = torch.tensor(trg_indices, device=device)[None, :].long()
+#         trg = network.target_embedding(trg)
+#         logits = network.decoder(trg=trg, memory=memory, trg_mask=None)
+#         log_prob = F.log_softmax(logits, dim=2)
+
+#         log_prob, indices = torch.topk(log_prob, beam_width)
+
+#         for new_k in range(beam_width):
+#             # TODO: continue from here
+#             token_index = indices[0][new_k].view(1, -1)
+#             log_p = log_prob[0][new_k].item()
+
+#             node = Node()
+
+#         pass
+
+#     pass
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py
index 16c7a41..b2b74b3 100644
--- a/src/text_recognizer/networks/cnn_transformer.py
+++ b/src/text_recognizer/networks/cnn_transformer.py
@@ -88,10 +88,14 @@ class CNNTransformer(nn.Module):
         if len(src.shape) < 4:
             src = src[(None,) * (4 - len(src.shape))]
         src = self.backbone(src)
-        src = rearrange(src, "b c h w -> b w c h")
+        if self.adaptive_pool is not None:
+            src = rearrange(src, "b c h w -> b w c h")
             src = self.adaptive_pool(src)
-        src = src.squeeze(3)
+            src = src.squeeze(3)
+        else:
+            src = rearrange(src, "b c h w -> b (w h) c")
+
         src = self.position_encoding(src)
         return src

@@ -110,12 +114,17 @@ class CNNTransformer(nn.Module):
         trg = self.position_encoding(trg)
         return trg

-    def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
-        """Forward pass with CNN transfomer."""
-        h = self.extract_image_features(x)
+    def decode_image_features(self, h: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+        """Takes image features from the backbone and decodes them with the transformer."""
         trg_mask = self._create_trg_mask(trg)
         trg = self.target_embedding(trg)
         out = self.transformer(h, trg, trg_mask=trg_mask)

         logits = self.head(out)
         return logits
+
+    def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+        """Forward pass with the CNN transformer."""
+        h = self.extract_image_features(x)
+        logits = self.decode_image_features(h, trg)
+        return logits
diff --git a/src/text_recognizer/networks/fcn.py b/src/text_recognizer/networks/fcn.py
deleted file mode 100644
index f9c4fd4..0000000
--- a/src/text_recognizer/networks/fcn.py
+++ /dev/null
@@ -1,99 +0,0 @@
-"""Fully Convolutional Network (FCN) with dilated kernels for global context."""
-from typing import List, Tuple, Type
-import torch
-from torch import nn
-from torch import Tensor
-
-
-from text_recognizer.networks.util import activation_function
-
-
-class _DilatedBlock(nn.Module):
-    def __init__(
-        self,
-        channels: List[int],
-        kernel_sizes: List[int],
-        dilations: List[int],
-        paddings: List[int],
-        activation_fn: Type[nn.Module],
-    ) -> None:
-        super().__init__()
-        self.dilation_conv = nn.Sequential(
-            nn.Conv2d(
-                in_channels=channels[0],
-                out_channels=channels[1],
-                kernel_size=kernel_sizes[0],
-                stride=1,
-                dilation=dilations[0],
-                padding=paddings[0],
-            ),
-            nn.Conv2d(
-                in_channels=channels[1],
-                out_channels=channels[1] // 2,
-                kernel_size=kernel_sizes[1],
-                stride=1,
-                dilation=dilations[1],
-                padding=paddings[1],
-            ),
-        )
-        self.activation_fn = activation_fn
-
-        self.conv = nn.Conv2d(
-            in_channels=channels[0],
-            out_channels=channels[1] // 2,
-            kernel_size=1,
-            dilation=1,
-            stride=1,
-        )
-
-    def forward(self, x: Tensor) -> Tensor:
-        residual = self.conv(x)
-        x = self.dilation_conv(x)
-        x = torch.cat((x, residual), dim=1)
-        return self.activation_fn(x)
-
-
-class FCN(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        base_channels: int,
out_channels: int, - kernel_size: int, - dilations: Tuple[int] = (3, 7), - paddings: Tuple[int] = (9, 21), - num_blocks: int = 14, - activation: str = "elu", - ) -> None: - super().__init__() - self.kernel_sizes = [kernel_size] * num_blocks - self.channels = [in_channels] + [base_channels] * (num_blocks - 1) - self.out_channels = out_channels - self.dilations = [dilations[0]] * (num_blocks // 2) + [dilations[1]] * ( - num_blocks // 2 - ) - self.paddings = [paddings[0]] * (num_blocks // 2) + [paddings[1]] * ( - num_blocks // 2 - ) - self.activation_fn = activation_function(activation) - self.fcn = self._configure_fcn() - - def _configure_fcn(self) -> nn.Sequential: - layers = [] - for i in range(0, len(self.channels), 2): - layers.append( - _DilatedBlock( - self.channels[i : i + 2], - self.kernel_sizes[i : i + 2], - self.dilations[i : i + 2], - self.paddings[i : i + 2], - self.activation_fn, - ) - ) - layers.append( - nn.Conv2d(self.channels[-1], self.out_channels, kernel_size=1, stride=1) - ) - return nn.Sequential(*layers) - - def forward(self, x: Tensor) -> Tensor: - return self.fcn(x) diff --git a/src/text_recognizer/networks/neural_machine_reader.py b/src/text_recognizer/networks/neural_machine_reader.py deleted file mode 100644 index 7f8c49b..0000000 --- a/src/text_recognizer/networks/neural_machine_reader.py +++ /dev/null @@ -1,201 +0,0 @@ -"""Sequence to sequence network with RNN cells.""" -# from typing import Dict, Optional, Tuple - -# from einops import rearrange -# from einops.layers.torch import Rearrange -# import torch -# from torch import nn -# from torch import Tensor - -# from text_recognizer.networks.util import configure_backbone - - -# class Encoder(nn.Module): -# def __init__( -# self, -# embedding_dim: int, -# encoder_dim: int, -# decoder_dim: int, -# dropout_rate: float = 0.1, -# ) -> None: -# super().__init__() -# self.rnn = nn.GRU( -# input_size=embedding_dim, hidden_size=encoder_dim, bidirectional=True -# ) -# self.fc = nn.Sequential( -# nn.Linear(in_features=2 * encoder_dim, out_features=decoder_dim), nn.Tanh() -# ) -# self.dropout = nn.Dropout(p=dropout_rate) - -# def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: -# """Encodes a sequence of tensors with a bidirectional GRU. - -# Args: -# x (Tensor): A input sequence. - -# Shape: -# - x: :math:`(T, N, E)`. -# - output[0]: :math:`(T, N, 2 * E)`. -# - output[1]: :math:`(T, N, D)`. - -# where T is the sequence length, N is the batch size, E is the -# embedding/encoder dimension, and D is the decoder dimension. - -# Returns: -# Tuple[Tensor, Tensor]: The encoder output and the hidden state of the -# encoder. - -# """ - -# output, hidden = self.rnn(x) - -# # Get the hidden state from the forward and backward rnn. -# hidden_state = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1) - -# # Apply fully connected layer and tanh activation. -# hidden_state = self.fc(hidden_state) - -# return output, hidden_state - - -# class Attention(nn.Module): -# def __init__(self, encoder_dim: int, decoder_dim: int) -> None: -# super().__init__() -# self.atten = nn.Linear( -# in_features=2 * encoder_dim + decoder_dim, out_features=decoder_dim -# ) -# self.value = nn.Linear(in_features=decoder_dim, out_features=1, bias=False) - -# def forward(self, hidden_state: Tensor, encoder_outputs: Tensor) -> Tensor: -# """Short summary. - -# Args: -# hidden_state (Tensor): Description of parameter `h`. -# encoder_outputs (Tensor): Description of parameter `enc_out`. - -# Shape: -# - x: :math:`(T, N, E)`. 
-# - output[0]: :math:`(T, N, 2 * E)`. -# - output[1]: :math:`(T, N, D)`. - -# where T is the sequence length, N is the batch size, E is the -# embedding/encoder dimension, and D is the decoder dimension. - -# Returns: -# Tensor: Description of returned object. - -# """ -# t, b = enc_out.shape[:2] -# # repeat decoder hidden state src_len times -# hidden_state = hidden_state.unsqueeze(1).repeat(1, t, 1) - -# encoder_outputs = rearrange(encoder_outputs, "t b e2 -> b t e2") - -# # Calculate the energy between the decoders previous hidden state and the -# # encoders hidden states. -# energy = torch.tanh( -# self.attn(torch.cat((hidden_state, encoder_outputs), dim=2)) -# ) - -# attention = self.value(energy).squeeze(2) - -# # Apply softmax on the attention to squeeze it between 0 and 1. -# attention = F.softmax(attention, dim=1) - -# return attention - - -# class Decoder(nn.Module): -# def __init__( -# self, -# embedding_dim: int, -# encoder_dim: int, -# decoder_dim: int, -# output_dim: int, -# dropout_rate: float = 0.1, -# ) -> None: -# super().__init__() -# self.output_dim = output_dim -# self.embedding = nn.Embedding(output_dim, embedding_dim) -# self.attention = Attention(encoder_dim, decoder_dim) -# self.rnn = nn.GRU( -# input_size=2 * encoder_dim + embedding_dim, hidden_size=decoder_dim -# ) - -# self.head = nn.Linear( -# in_features=2 * encoder_dim + embedding_dim + decoder_dim, -# out_features=output_dim, -# ) -# self.dropout = nn.Dropout(p=dropout_rate) - -# def forward( -# self, trg: Tensor, hidden_state: Tensor, encoder_outputs: Tensor -# ) -> Tensor: -# # input = [batch size] -# # hidden = [batch size, dec hid dim] -# # encoder_outputs = [src len, batch size, enc hid dim * 2] -# trg = trg.unsqueeze(0) -# trg_embedded = self.dropout(self.embedding(trg)) - -# a = self.attention(hidden_state, encoder_outputs) - -# weighted = torch.bmm(a, encoder_outputs) - -# # Permutate the tensor. 
-# weighted = rearrange(weighted, "b a e2 -> a b e2") - -# rnn_input = torch.cat((trg_embedded, weighted), dim=2) - -# output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0)) - -# # seq len, n layers and n directions will always be 1 in this decoder, therefore: -# # output = [1, batch size, dec hid dim] -# # hidden = [1, batch size, dec hid dim] -# # this also means that output == hidden -# assert (output == hidden).all() - -# trg_embedded = trg_embedded.squeeze(0) -# output = output.squeeze(0) -# weighted = weighted.squeeze(0) - -# logits = self.fc_out(torch.cat((output, weighted, trg_embedded), dim=1)) - -# # prediction = [batch size, output dim] - -# return logits, hidden.squeeze(0) - - -# class NeuralMachineReader(nn.Module): -# def __init__( -# self, -# embedding_dim: int, -# encoder_dim: int, -# decoder_dim: int, -# output_dim: int, -# backbone: Optional[str] = None, -# backbone_args: Optional[Dict] = None, -# adaptive_pool_dim: Tuple = (None, 1), -# dropout_rate: float = 0.1, -# teacher_forcing_ratio: float = 0.5, -# ) -> None: -# super().__init__() - -# self.backbone = configure_backbone(backbone, backbone_args) -# self.adaptive_pool = nn.AdaptiveAvgPool2d((adaptive_pool_dim)) - -# self.encoder = Encoder(embedding_dim, encoder_dim, decoder_dim, dropout_rate) -# self.decoder = Decoder( -# embedding_dim, encoder_dim, decoder_dim, output_dim, dropout_rate -# ) -# self.teacher_forcing_ratio = teacher_forcing_ratio - -# def extract_image_features(self, x: Tensor) -> Tensor: -# x = self.backbone(x) -# x = rearrange(x, "b c h w -> b w c h") -# x = self.adaptive_pool(x) -# x = x.squeeze(3) - -# def forward(self, x: Tensor, trg: Tensor) -> Tensor: -# # x = [batch size, height, width] -# # trg = [trg len, batch size] -# z = self.extract_image_features(x) diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py index 6405192..e397224 100644 --- a/src/text_recognizer/networks/residual_network.py +++ b/src/text_recognizer/networks/residual_network.py @@ -7,7 +7,6 @@ import torch from torch import nn from torch import Tensor -from text_recognizer.networks.stn import SpatialTransformerNetwork from text_recognizer.networks.util import activation_function @@ -209,12 +208,10 @@ class ResidualNetworkEncoder(nn.Module): activation: str = "relu", block: Type[nn.Module] = BasicBlock, levels: int = 1, - stn: bool = False, *args, **kwargs ) -> None: super().__init__() - self.stn = SpatialTransformerNetwork() if stn else None self.block_sizes = ( block_sizes if isinstance(block_sizes, list) else [block_sizes] * levels ) @@ -231,7 +228,7 @@ class ResidualNetworkEncoder(nn.Module): ), nn.BatchNorm2d(self.block_sizes[0]), activation_function(self.activation), - nn.MaxPool2d(kernel_size=2, stride=2, padding=1), + # nn.MaxPool2d(kernel_size=2, stride=2, padding=1), ) self.blocks = self._configure_blocks(block) @@ -275,8 +272,6 @@ class ResidualNetworkEncoder(nn.Module): # If batch dimenstion is missing, it needs to be added. 
if len(x.shape) == 3: x = x.unsqueeze(0) - if self.stn is not None: - x = self.stn(x) x = self.gate(x) x = self.blocks(x) return x diff --git a/src/text_recognizer/networks/unet.py b/src/text_recognizer/networks/unet.py index 51f242a..510910f 100644 --- a/src/text_recognizer/networks/unet.py +++ b/src/text_recognizer/networks/unet.py @@ -8,64 +8,118 @@ from torch import Tensor from text_recognizer.networks.util import activation_function -class ConvBlock(nn.Module): - """Basic UNet convolutional block.""" +class _ConvBlock(nn.Module): + """Modified UNet convolutional block with dilation.""" - def __init__(self, channels: List[int], activation: str) -> None: + def __init__( + self, + channels: List[int], + activation: str, + num_groups: int, + dropout_rate: float = 0.1, + kernel_size: int = 3, + dilation: int = 1, + padding: int = 0, + ) -> None: super().__init__() self.channels = channels + self.dropout_rate = dropout_rate + self.kernel_size = kernel_size + self.dilation = dilation + self.padding = padding + self.num_groups = num_groups self.activation = activation_function(activation) self.block = self._configure_block() + self.residual_conv = nn.Sequential( + nn.Conv2d( + self.channels[0], self.channels[-1], kernel_size=3, stride=1, padding=1 + ), + self.activation, + ) def _configure_block(self) -> nn.Sequential: block = [] for i in range(len(self.channels) - 1): block += [ + nn.Dropout(p=self.dropout_rate), + nn.GroupNorm(self.num_groups, self.channels[i]), + self.activation, nn.Conv2d( - self.channels[i], self.channels[i + 1], kernel_size=3, padding=1 + self.channels[i], + self.channels[i + 1], + kernel_size=self.kernel_size, + padding=self.padding, + stride=1, + dilation=self.dilation, ), - nn.BatchNorm2d(self.channels[i + 1]), - self.activation, ] return nn.Sequential(*block) def forward(self, x: Tensor) -> Tensor: """Apply the convolutional block.""" - return self.block(x) + residual = self.residual_conv(x) + return self.block(x) + residual -class DownSamplingBlock(nn.Module): +class _DownSamplingBlock(nn.Module): """Basic down sampling block.""" def __init__( self, channels: List[int], activation: str, + num_groups: int, pooling_kernel: Union[int, bool] = 2, + dropout_rate: float = 0.1, + kernel_size: int = 3, + dilation: int = 1, + padding: int = 0, ) -> None: super().__init__() - self.conv_block = ConvBlock(channels, activation) + self.conv_block = _ConvBlock( + channels, + activation, + num_groups, + dropout_rate, + kernel_size, + dilation, + padding, + ) self.down_sampling = nn.MaxPool2d(pooling_kernel) if pooling_kernel else None def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: """Return the convolutional block output and a down sampled tensor.""" x = self.conv_block(x) - if self.down_sampling is not None: - x_down = self.down_sampling(x) - else: - x_down = None + x_down = self.down_sampling(x) if self.down_sampling is not None else x + return x_down, x -class UpSamplingBlock(nn.Module): +class _UpSamplingBlock(nn.Module): """The upsampling block of the UNet.""" def __init__( - self, channels: List[int], activation: str, scale_factor: int = 2 + self, + channels: List[int], + activation: str, + num_groups: int, + scale_factor: int = 2, + dropout_rate: float = 0.1, + kernel_size: int = 3, + dilation: int = 1, + padding: int = 0, ) -> None: super().__init__() - self.conv_block = ConvBlock(channels, activation) + self.conv_block = _ConvBlock( + channels, + activation, + num_groups, + dropout_rate, + kernel_size, + dilation, + padding, + ) self.up_sampling = nn.Upsample( 
            scale_factor=scale_factor, mode="bilinear", align_corners=True
        )

@@ -87,14 +141,43 @@ class UNet(nn.Module):
         base_channels: int = 64,
         num_classes: int = 3,
         depth: int = 4,
-        out_channels: int = 3,
         activation: str = "relu",
+        num_groups: int = 8,
+        dropout_rate: float = 0.1,
         pooling_kernel: int = 2,
         scale_factor: int = 2,
+        kernel_size: Optional[List[int]] = None,
+        dilation: Optional[List[int]] = None,
+        padding: Optional[List[int]] = None,
     ) -> None:
         super().__init__()
         self.depth = depth
-        channels = [1] + [base_channels * 2 ** i for i in range(depth)]
+        self.num_groups = num_groups
+
+        if kernel_size is not None and dilation is not None and padding is not None:
+            if (
+                len(kernel_size) != depth
+                or len(dilation) != depth
+                or len(padding) != depth
+            ):
+                raise RuntimeError(
+                    "Length of convolutional parameters does not match the depth."
+                )
+            self.kernel_size = kernel_size
+            self.padding = padding
+            self.dilation = dilation
+
+        else:
+            self.kernel_size = [3] * depth
+            self.padding = [1] * depth
+            self.dilation = [1] * depth
+
+        self.dropout_rate = dropout_rate
+        self.conv = nn.Conv2d(
+            in_channels, base_channels, kernel_size=3, stride=1, padding=1
+        )
+
+        channels = [base_channels] + [base_channels * 2 ** i for i in range(depth)]
         self.encoder_blocks = self._configure_down_sampling_blocks(
             channels, activation, pooling_kernel
         )
@@ -110,49 +193,63 @@ class UNet(nn.Module):
         blocks = nn.ModuleList([])
         for i in range(len(channels) - 1):
             pooling_kernel = pooling_kernel if i < self.depth - 1 else False
+            dropout_rate = self.dropout_rate if i < 0 else 0
             blocks += [
-                DownSamplingBlock(
+                _DownSamplingBlock(
                     [channels[i], channels[i + 1], channels[i + 1]],
                     activation,
+                    self.num_groups,
                     pooling_kernel,
+                    dropout_rate,
+                    self.kernel_size[i],
+                    self.dilation[i],
+                    self.padding[i],
                 )
             ]

         return blocks

     def _configure_up_sampling_blocks(
-        self,
-        channels: List[int],
-        activation: str,
-        scale_factor: int,
+        self, channels: List[int], activation: str, scale_factor: int,
     ) -> nn.ModuleList:
         channels.reverse()
+        self.kernel_size.reverse()
+        self.dilation.reverse()
+        self.padding.reverse()
         return nn.ModuleList(
             [
-                UpSamplingBlock(
+                _UpSamplingBlock(
                     [channels[i] + channels[i + 1], channels[i + 1], channels[i + 1]],
                     activation,
+                    self.num_groups,
                     scale_factor,
+                    self.dropout_rate,
+                    self.kernel_size[i],
+                    self.dilation[i],
+                    self.padding[i],
                 )
                 for i in range(len(channels) - 2)
             ]
         )

-    def encode(self, x: Tensor) -> Tuple[Tensor, List[Tensor]]:
+    def _encode(self, x: Tensor) -> List[Tensor]:
         x_skips = []
         for block in self.encoder_blocks:
             x, x_skip = block(x)
-            if x_skip is not None:
-                x_skips.append(x_skip)
-        return x, x_skips
+            x_skips.append(x_skip)
+        return x_skips

-    def decode(self, x: Tensor, x_skips: List[Tensor]) -> Tensor:
+    def _decode(self, x_skips: List[Tensor]) -> Tensor:
+        x = x_skips[-1]
         for i, block in enumerate(self.decoder_blocks):
             x = block(x, x_skips[-(i + 2)])
         return x

     def forward(self, x: Tensor) -> Tensor:
-        x, x_skips = self.encode(x)
-        x = self.decode(x, x_skips)
+        """Forward pass with the UNet model."""
+        if len(x.shape) < 4:
+            x = x[(None,) * (4 - len(x.shape))]
+        x = self.conv(x)
+        x_skips = self._encode(x)
+        x = self._decode(x_skips)
         return self.head(x)
--
cgit v1.2.3-70-g09d2
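
As a quick sanity check, the reworked UNet can be smoke-tested as below. This is a minimal sketch, not part of the commit: it assumes the default constructor arguments shown in the diff (depth=4, base_channels=64, GroupNorm with 8 groups, 3x3 convolutions with padding 1) and that the unchanged `self.head` still maps the decoder output to `num_classes` channels. The input resolution is an arbitrary choice; it only needs to be divisible by 2 ** (depth - 1) = 8 so the three max-pooling steps halve cleanly and the bilinear upsampling restores the original size for the skip concatenations.

    import torch

    from text_recognizer.networks import UNet

    net = UNet(in_channels=1, num_classes=3)
    net.eval()  # Disable dropout for the smoke test.

    # Dummy batch of grayscale images; 256 is divisible by 2 ** 3 = 8.
    x = torch.randn(2, 1, 256, 256)

    with torch.no_grad():
        logits = net(x)

    # Per-pixel class logits, expected shape: torch.Size([2, 3, 256, 256]).
    print(logits.shape)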