author     aktersnurra <gustaf.rydholm@gmail.com>  2020-11-08 12:41:04 +0100
committer  aktersnurra <gustaf.rydholm@gmail.com>  2020-11-08 12:41:04 +0100
commit     beeaef529e7c893a3475fe27edc880e283373725 (patch)
tree       59eb72562bf7a5a9470c2586e6280600ad94f1ae /src/text_recognizer/networks
parent     4d7713746eb936832e84852e90292936b933e87d (diff)
Trying to get the CNNTransformer to work, but it is hard.
Diffstat (limited to 'src/text_recognizer/networks')
-rw-r--r--  src/text_recognizer/networks/__init__.py                         |  2
-rw-r--r--  src/text_recognizer/networks/cnn_transformer.py                  | 58
-rw-r--r--  src/text_recognizer/networks/cnn_transformer_encoder.py          | 73
-rw-r--r--  src/text_recognizer/networks/crnn.py                             | 40
-rw-r--r--  src/text_recognizer/networks/ctc.py                              |  2
-rw-r--r--  src/text_recognizer/networks/densenet.py                         |  4
-rw-r--r--  src/text_recognizer/networks/loss.py                             | 39
-rw-r--r--  src/text_recognizer/networks/transformer/positional_encoding.py |  1
-rw-r--r--  src/text_recognizer/networks/transformer/sparse_transformer.py  |  1
-rw-r--r--  src/text_recognizer/networks/transformer/transformer.py         |  1
-rw-r--r--  src/text_recognizer/networks/util.py                             |  4
-rw-r--r--  src/text_recognizer/networks/vision_transformer.py               | 19
12 files changed, 197 insertions(+), 47 deletions(-)
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
index 8b87797..6d88768 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/src/text_recognizer/networks/__init__.py
@@ -1,5 +1,6 @@
"""Network modules."""
from .cnn_transformer import CNNTransformer
+from .cnn_transformer_encoder import CNNTransformerEncoder
from .crnn import ConvolutionalRecurrentNetwork
from .ctc import greedy_decoder
from .densenet import DenseNet
@@ -15,6 +16,7 @@ from .wide_resnet import WideResidualNetwork
__all__ = [
"CNNTransformer",
+ "CNNTransformerEncoder",
"ConvolutionalRecurrentNetwork",
"DenseNet",
"EmbeddingLoss",
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py
index 8666f11..3da2c9f 100644
--- a/src/text_recognizer/networks/cnn_transformer.py
+++ b/src/text_recognizer/networks/cnn_transformer.py
@@ -1,8 +1,7 @@
"""A DETR style transfomers but for text recognition."""
-from typing import Dict, Optional, Tuple, Type
+from typing import Dict, Optional, Tuple
-from einops.layers.torch import Rearrange
-from loguru import logger
+from einops import rearrange
import torch
from torch import nn
from torch import Tensor
@@ -21,23 +20,32 @@ class CNNTransformer(nn.Module):
hidden_dim: int,
vocab_size: int,
num_heads: int,
- max_len: int,
+ adaptive_pool_dim: Tuple,
expansion_dim: int,
dropout_rate: float,
trg_pad_index: int,
backbone: str,
+ out_channels: int,
+ max_len: int,
backbone_args: Optional[Dict] = None,
activation: str = "gelu",
) -> None:
super().__init__()
self.trg_pad_index = trg_pad_index
- self.backbone_args = backbone_args
+
self.backbone = configure_backbone(backbone, backbone_args)
self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
- self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate, max_len)
- self.collapse_spatial_dim = nn.Sequential(
- Rearrange("b t h w -> b t (h w)"), nn.AdaptiveAvgPool2d((None, hidden_dim))
+
+ # self.conv = nn.Conv2d(out_channels, max_len, kernel_size=1)
+
+ self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+ self.row_embed = nn.Parameter(torch.rand(max_len, max_len // 2))
+ self.col_embed = nn.Parameter(torch.rand(max_len, max_len // 2))
+
+ self.adaptive_pool = (
+ nn.AdaptiveAvgPool2d(adaptive_pool_dim) if adaptive_pool_dim else None
)
+
self.transformer = Transformer(
num_encoder_layers,
num_decoder_layers,
@@ -47,7 +55,8 @@ class CNNTransformer(nn.Module):
dropout_rate,
activation,
)
- self.head = nn.Linear(hidden_dim, vocab_size)
+
+ self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
def _create_trg_mask(self, trg: Tensor) -> Tensor:
# Move this outside the transformer.
@@ -83,8 +92,22 @@ class CNNTransformer(nn.Module):
if len(src.shape) < 4:
src = src[(None,) * (4 - len(src.shape))]
src = self.backbone(src)
- src = self.collapse_spatial_dim(src)
- src = self.position_encoding(src)
+ # src = self.conv(src)
+ if self.adaptive_pool is not None:
+ src = self.adaptive_pool(src)
+ H, W = src.shape[-2:]
+ src = rearrange(src, "b t h w -> b t (h w)")
+
+ # construct positional encodings
+ pos = torch.cat(
+ [
+ self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
+ self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
+ ],
+ dim=-1,
+ ).unsqueeze(0)
+ pos = rearrange(pos, "b h w l -> b l (h w)")
+ src = pos + 0.1 * src
return src
def preprocess_target(self, trg: Tensor) -> Tuple[Tensor, Tensor]:
@@ -97,15 +120,16 @@ class CNNTransformer(nn.Module):
Tuple[Tensor, Tensor]: Encoded target tensor and target mask.
"""
- trg_mask = self._create_trg_mask(trg)
trg = self.character_embedding(trg.long())
trg = self.position_encoding(trg)
- return trg, trg_mask
+ return trg
- def forward(self, x: Tensor, trg: Tensor) -> Tensor:
+ def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
"""Forward pass with CNN transfomer."""
- src = self.preprocess_input(x)
- trg, trg_mask = self.preprocess_target(trg)
- out = self.transformer(src, trg, trg_mask=trg_mask)
+ h = self.preprocess_input(x)
+ trg_mask = self._create_trg_mask(trg)
+ trg = self.preprocess_target(trg)
+ out = self.transformer(h, trg, trg_mask=trg_mask)
+
logits = self.head(out)
return logits
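For orientation, below is a minimal, self-contained sketch of the learned 2D positional embedding that the new preprocess_input builds: one learned vector per column index and one per row index, concatenated at every spatial location, DETR-style. All sizes are invented for illustration; note that the final addition only broadcasts if the backbone emits exactly max_len channels, and the 0.1 factor on the features is taken from the diff above.

```python
import torch
from einops import rearrange

# Invented sizes: H, W <= max_len, and the backbone must emit max_len channels.
B, H, W, max_len = 2, 4, 16, 32

col_embed = torch.rand(max_len, max_len // 2)        # learned column embeddings
row_embed = torch.rand(max_len, max_len // 2)        # learned row embeddings

pos = torch.cat(
    [
        col_embed[:W].unsqueeze(0).repeat(H, 1, 1),  # (H, W, max_len // 2)
        row_embed[:H].unsqueeze(1).repeat(1, W, 1),  # (H, W, max_len // 2)
    ],
    dim=-1,
).unsqueeze(0)                                       # (1, H, W, max_len)
pos = rearrange(pos, "b h w l -> b l (h w)")         # (1, max_len, H * W)

src = torch.rand(B, max_len, H, W)                   # stand-in backbone features
src = rearrange(src, "b t h w -> b t (h w)")         # (B, max_len, H * W)
src = pos + 0.1 * src                                # broadcasts over the batch dim
```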
diff --git a/src/text_recognizer/networks/cnn_transformer_encoder.py b/src/text_recognizer/networks/cnn_transformer_encoder.py
new file mode 100644
index 0000000..93626bf
--- /dev/null
+++ b/src/text_recognizer/networks/cnn_transformer_encoder.py
@@ -0,0 +1,73 @@
+"""Network with a CNN backend and a transformer encoder head."""
+from typing import Dict, Optional
+
+from einops import rearrange
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer import PositionalEncoding
+from text_recognizer.networks.util import configure_backbone
+
+
+class CNNTransformerEncoder(nn.Module):
+ """A CNN backbone with Transformer Encoder frontend for sequence prediction."""
+
+ def __init__(
+ self,
+ backbone: str,
+ backbone_args: Dict,
+ mlp_dim: int,
+ d_model: int,
+ nhead: int = 8,
+ dropout_rate: float = 0.1,
+ activation: str = "relu",
+ num_layers: int = 6,
+ num_classes: int = 80,
+ num_channels: int = 256,
+ max_len: int = 97,
+ ) -> None:
+ super().__init__()
+ self.d_model = d_model
+ self.nhead = nhead
+ self.dropout_rate = dropout_rate
+ self.activation = activation
+ self.num_layers = num_layers
+
+ self.backbone = configure_backbone(backbone, backbone_args)
+ self.position_encoding = PositionalEncoding(d_model, dropout_rate)
+ self.encoder = self._configure_encoder()
+
+ self.conv = nn.Conv2d(num_channels, max_len, kernel_size=1)
+
+ self.mlp = nn.Linear(mlp_dim, d_model)
+
+ self.head = nn.Linear(d_model, num_classes)
+
+ def _configure_encoder(self) -> nn.TransformerEncoder:
+ encoder_layer = nn.TransformerEncoderLayer(
+ d_model=self.d_model,
+ nhead=self.nhead,
+ dropout=self.dropout_rate,
+ activation=self.activation,
+ )
+ norm = nn.LayerNorm(self.d_model)
+ return nn.TransformerEncoder(
+ encoder_layer=encoder_layer, num_layers=self.num_layers, norm=norm
+ )
+
+ def forward(self, x: Tensor, targets: Optional[Tensor] = None) -> Tensor:
+ """Forward pass through the network."""
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
+
+ x = self.conv(self.backbone(x))
+ x = rearrange(x, "b c h w -> b c (h w)")
+ x = self.mlp(x)
+ x = self.position_encoding(x)
+ x = rearrange(x, "b c h-> c b h")
+ x = self.encoder(x)
+ x = rearrange(x, "c b h-> b c h")
+ logits = self.head(x)
+
+ return logits
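The non-obvious part of the new module is the shape flow: the 1x1 convolution repurposes the backbone's channel axis as the output sequence length, and mlp_dim is implicitly tied to the flattened spatial size H * W of the feature map. A shape-only sketch with made-up sizes:

```python
import torch
from einops import rearrange
from torch import nn

# Made-up sizes; mlp_dim must equal H * W of the backbone feature map.
B, num_channels, H, W = 2, 256, 4, 24
max_len, d_model, num_classes = 97, 128, 80
mlp_dim = H * W

x = torch.rand(B, num_channels, H, W)            # stand-in backbone output
x = nn.Conv2d(num_channels, max_len, 1)(x)       # (B, max_len, H, W): channels -> time steps
x = rearrange(x, "b c h w -> b c (h w)")         # (B, max_len, H * W)
x = nn.Linear(mlp_dim, d_model)(x)               # (B, max_len, d_model)
x = rearrange(x, "b c h -> c b h")               # (S, N, E), as nn.TransformerEncoder expects
layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=8)
x = nn.TransformerEncoder(layer, num_layers=1)(x)
x = rearrange(x, "c b h -> b c h")
logits = nn.Linear(d_model, num_classes)(x)      # (B, max_len, num_classes)
```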
diff --git a/src/text_recognizer/networks/crnn.py b/src/text_recognizer/networks/crnn.py
index 3e605e2..9747429 100644
--- a/src/text_recognizer/networks/crnn.py
+++ b/src/text_recognizer/networks/crnn.py
@@ -1,12 +1,9 @@
"""LSTM with CTC for handwritten text recognition within a line."""
-import importlib
-from pathlib import Path
-from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+from typing import Dict, Tuple
from einops import rearrange, reduce
-from einops.layers.torch import Rearrange, Reduce
+from einops.layers.torch import Rearrange
from loguru import logger
-import torch
from torch import nn
from torch import Tensor
@@ -28,16 +25,21 @@ class ConvolutionalRecurrentNetwork(nn.Module):
patch_size: Tuple[int, int] = (28, 28),
stride: Tuple[int, int] = (1, 14),
recurrent_cell: str = "lstm",
+ avg_pool: bool = False,
+ use_sliding_window: bool = True,
) -> None:
super().__init__()
self.backbone_args = backbone_args or {}
self.patch_size = patch_size
self.stride = stride
- self.sliding_window = self._configure_sliding_window()
+ self.sliding_window = (
+ self._configure_sliding_window() if use_sliding_window else None
+ )
self.input_size = input_size
self.hidden_size = hidden_size
self.backbone = configure_backbone(backbone, backbone_args)
self.bidirectional = bidirectional
+ self.avg_pool = avg_pool
if recurrent_cell.upper() in ["LSTM", "GRU"]:
recurrent_cell = getattr(nn, recurrent_cell)
@@ -76,15 +78,27 @@ class ConvolutionalRecurrentNetwork(nn.Module):
"""Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM."""
if len(x.shape) < 4:
x = x[(None,) * (4 - len(x.shape))]
- x = self.sliding_window(x)
- # Rearrange from a sequence of patches for feedforward network.
- b, t = x.shape[:2]
- x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
- x = self.backbone(x)
+ if self.sliding_window is not None:
+ # Create image patches with a sliding window kernel.
+ x = self.sliding_window(x)
+
+ # Rearrange from a sequence of patches for feedforward network.
+ b, t = x.shape[:2]
+ x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
- # Avgerage pooling.
- x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)
+ x = self.backbone(x)
+
+ # Average pooling.
+ if self.avg_pool:
+ x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)
+ else:
+ x = rearrange(x, "(b t) h -> t b h", b=b, t=t)
+ else:
+ # Encode the entire image with a CNN, and use the channels as the temporal dimension.
+ b = x.shape[0]
+ x = self.backbone(x)
+ x = rearrange(x, "b c h w -> c b (h w)", b=b)
# Sequence predictions.
x, _ = self.rnn(x)
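In the new use_sliding_window=False branch, the whole image is encoded once and the CNN's channel axis is reinterpreted as the time axis for the recurrent layer. A shape-only sketch with toy sizes; the RNN's input_size must then equal the flattened H * W of the feature map:

```python
import torch
from einops import rearrange

B, C, H, W = 2, 64, 4, 8                          # toy sizes
features = torch.rand(B, C, H, W)                 # stand-in CNN encoding of the full image
x = rearrange(features, "b c h w -> c b (h w)")   # (T=C, B, H * W): channels become time steps
rnn = torch.nn.LSTM(input_size=H * W, hidden_size=32, bidirectional=True)
out, _ = rnn(x)                                   # (C, B, 2 * 32)
```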
diff --git a/src/text_recognizer/networks/ctc.py b/src/text_recognizer/networks/ctc.py
index 2493d5c..af9b700 100644
--- a/src/text_recognizer/networks/ctc.py
+++ b/src/text_recognizer/networks/ctc.py
@@ -33,7 +33,7 @@ def greedy_decoder(
"""
if character_mapper is None:
- character_mapper = EmnistMapper()
+ character_mapper = EmnistMapper(pad_token="_") # noqa: S106
predictions = rearrange(torch.argmax(predictions, dim=2), "t b -> b t")
decoded_predictions = []
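For reference, the argmax step shown in the context above flips the time-major CTC output into batch-major index sequences; a toy shape check:

```python
import torch
from einops import rearrange

T, B, num_classes = 5, 2, 10                  # toy sizes
predictions = torch.randn(T, B, num_classes)  # time-major network output
best_paths = rearrange(torch.argmax(predictions, dim=2), "t b -> b t")  # (B, T)
```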
diff --git a/src/text_recognizer/networks/densenet.py b/src/text_recognizer/networks/densenet.py
index d2aad60..7dc58d9 100644
--- a/src/text_recognizer/networks/densenet.py
+++ b/src/text_recognizer/networks/densenet.py
@@ -72,7 +72,7 @@ class _DenseBlock(nn.Module):
) -> None:
super().__init__()
self.dense_block = self._build_dense_blocks(
- num_layers, in_channels, bn_size, growth_rate, dropout_rate, activation
+ num_layers, in_channels, bn_size, growth_rate, dropout_rate, activation,
)
def _build_dense_blocks(
@@ -219,7 +219,7 @@ class DenseNet(nn.Module):
def forward(self, x: Tensor) -> Tensor:
"""Forward pass of Densenet."""
- # If batch dimenstion is missing, it needs to be added.
+ # If the batch dimension is missing, it will be added.
if len(x.shape) < 4:
x = x[(None,) * (4 - len(x.shape))]
return self.densenet(x)
diff --git a/src/text_recognizer/networks/loss.py b/src/text_recognizer/networks/loss.py
index ff843cf..cf9fa0d 100644
--- a/src/text_recognizer/networks/loss.py
+++ b/src/text_recognizer/networks/loss.py
@@ -1,10 +1,12 @@
"""Implementations of custom loss functions."""
from pytorch_metric_learning import distances, losses, miners, reducers
+import torch
from torch import nn
from torch import Tensor
+from torch.autograd import Variable
+import torch.nn.functional as F
-
-__all__ = ["EmbeddingLoss"]
+__all__ = ["EmbeddingLoss", "LabelSmoothingCrossEntropy"]
class EmbeddingLoss:
@@ -32,3 +34,36 @@ class EmbeddingLoss:
hard_pairs = self.miner(embeddings, labels)
loss = self.loss_fn(embeddings, labels, hard_pairs)
return loss
+
+
+class LabelSmoothingCrossEntropy(nn.Module):
+ """Label smoothing loss function."""
+
+ def __init__(
+ self,
+ classes: int,
+ smoothing: float = 0.0,
+ ignore_index: int = None,
+ dim: int = -1,
+ ) -> None:
+ super().__init__()
+ self.confidence = 1.0 - smoothing
+ self.smoothing = smoothing
+ self.ignore_index = ignore_index
+ self.cls = classes
+ self.dim = dim
+
+ def forward(self, pred: Tensor, target: Tensor) -> Tensor:
+ """Calculates the loss."""
+ pred = pred.log_softmax(dim=self.dim)
+ with torch.no_grad():
+ # true_dist = pred.data.clone()
+ true_dist = torch.zeros_like(pred)
+ true_dist.fill_(self.smoothing / (self.cls - 1))
+ true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
+ if self.ignore_index is not None:
+ true_dist[:, self.ignore_index] = 0
+ mask = torch.nonzero(target == self.ignore_index, as_tuple=False)
+ if mask.dim() > 0:
+ true_dist.index_fill_(0, mask.squeeze(), 0.0)
+ return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
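The new loss builds the smoothed target distribution in place: every wrong class receives smoothing / (classes - 1) probability mass, the true class receives 1 - smoothing, and an optional padding class is zeroed out. A usage sketch with hypothetical class count and padding index:

```python
import torch

from text_recognizer.networks.loss import LabelSmoothingCrossEntropy

# Hypothetical sizes: 80 classes with index 79 acting as padding.
criterion = LabelSmoothingCrossEntropy(classes=80, smoothing=0.1, ignore_index=79)
logits = torch.randn(4, 80)             # raw scores; log_softmax is applied inside forward
targets = torch.randint(0, 79, (4,))    # class indices, avoiding the padding class here
loss = criterion(logits, targets)       # scalar tensor
```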
diff --git a/src/text_recognizer/networks/transformer/positional_encoding.py b/src/text_recognizer/networks/transformer/positional_encoding.py
index a47141b..1ba5537 100644
--- a/src/text_recognizer/networks/transformer/positional_encoding.py
+++ b/src/text_recognizer/networks/transformer/positional_encoding.py
@@ -13,6 +13,7 @@ class PositionalEncoding(nn.Module):
) -> None:
super().__init__()
self.dropout = nn.Dropout(p=dropout_rate)
+ self.max_len = max_len
pe = torch.zeros(max_len, hidden_dim)
position = torch.arange(0, max_len).unsqueeze(1)
diff --git a/src/text_recognizer/networks/transformer/sparse_transformer.py b/src/text_recognizer/networks/transformer/sparse_transformer.py
deleted file mode 100644
index 8c391c8..0000000
--- a/src/text_recognizer/networks/transformer/sparse_transformer.py
+++ /dev/null
@@ -1 +0,0 @@
-"""Encoder and Decoder modules using spares activations."""
diff --git a/src/text_recognizer/networks/transformer/transformer.py b/src/text_recognizer/networks/transformer/transformer.py
index 1c9c7dd..c6e943e 100644
--- a/src/text_recognizer/networks/transformer/transformer.py
+++ b/src/text_recognizer/networks/transformer/transformer.py
@@ -230,6 +230,7 @@ class Transformer(nn.Module):
) -> Tensor:
"""Forward pass through the transformer."""
if src.shape[0] != trg.shape[0]:
+ print(trg.shape)
raise RuntimeError("The batch size of the src and trg must be the same.")
if src.shape[2] != trg.shape[2]:
raise RuntimeError(
diff --git a/src/text_recognizer/networks/util.py b/src/text_recognizer/networks/util.py
index 0d08506..b31e640 100644
--- a/src/text_recognizer/networks/util.py
+++ b/src/text_recognizer/networks/util.py
@@ -28,7 +28,7 @@ def sliding_window(
c = images.shape[1]
patches = unfold(images)
patches = rearrange(
- patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1]
+ patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1],
)
return patches
@@ -77,7 +77,7 @@ def configure_backbone(backbone: str, backbone_args: Dict) -> Type[nn.Module]:
if "remove_layers" in backbone_args and backbone_args["remove_layers"] is not None:
backbone = nn.Sequential(
- *list(backbone.children())[0][: -backbone_args["remove_layers"]]
+ *list(backbone.children())[:][: -backbone_args["remove_layers"]]
)
return backbone
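The behavioral change in configure_backbone is subtle: the old code sliced layers out of the backbone's first child, while the new code slices top-level children (the extra [:] is just a list copy and has no effect on the result). A toy illustration with an invented backbone:

```python
from torch import nn

# Invented backbone with four top-level children.
backbone = nn.Sequential(
    nn.Conv2d(1, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3), nn.AdaptiveAvgPool2d(1),
)
remove_layers = 2
trimmed = nn.Sequential(*list(backbone.children())[:][:-remove_layers])
print(trimmed)  # keeps Conv2d -> ReLU, drops the last two children
```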
diff --git a/src/text_recognizer/networks/vision_transformer.py b/src/text_recognizer/networks/vision_transformer.py
index 4d204d3..f227954 100644
--- a/src/text_recognizer/networks/vision_transformer.py
+++ b/src/text_recognizer/networks/vision_transformer.py
@@ -29,9 +29,9 @@ class VisionTransformer(nn.Module):
num_heads: int,
max_len: int,
expansion_dim: int,
- mlp_dim: int,
dropout_rate: float,
trg_pad_index: int,
+ mlp_dim: Optional[int] = None,
patch_size: Tuple[int, int] = (28, 28),
stride: Tuple[int, int] = (1, 14),
activation: str = "gelu",
@@ -46,6 +46,7 @@ class VisionTransformer(nn.Module):
self.slidning_window = self._configure_sliding_window()
self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate, max_len)
+ self.mlp_dim = mlp_dim
self.use_backbone = False
if backbone is None:
@@ -54,6 +55,8 @@ class VisionTransformer(nn.Module):
)
else:
self.backbone = configure_backbone(backbone, backbone_args)
+ if mlp_dim:
+ self.mlp = nn.Linear(mlp_dim, hidden_dim)
self.use_backbone = True
self.transformer = Transformer(
@@ -66,13 +69,7 @@ class VisionTransformer(nn.Module):
activation,
)
- self.head = nn.Sequential(
- nn.LayerNorm(hidden_dim),
- nn.Linear(hidden_dim, mlp_dim),
- nn.GELU(),
- nn.Dropout(p=dropout_rate),
- nn.Linear(mlp_dim, vocab_size),
- )
+ self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
def _configure_sliding_window(self) -> nn.Sequential:
return nn.Sequential(
@@ -110,7 +107,11 @@ class VisionTransformer(nn.Module):
if self.use_backbone:
x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
x = self.backbone(x)
- x = rearrange(x, "(b t) h -> b t h", b=b, t=t)
+ if self.mlp_dim:
+ x = rearrange(x, "(b t) c h w -> b t (c h w)", b=b, t=t)
+ x = self.mlp(x)
+ else:
+ x = rearrange(x, "(b t) h -> b t h", b=b, t=t)
else:
x = rearrange(x, "b t c h w -> b t (c h w)", b=b, t=t)
x = self.linear_projection(x)
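The new mlp_dim branch flattens each patch encoding coming out of the backbone and projects it down to hidden_dim, instead of assuming the backbone already emits flat vectors. A shape-only sketch with invented sizes; mlp_dim must equal the flattened C * H * W of a single patch encoding:

```python
import torch
from einops import rearrange
from torch import nn

B, T, C, H, W = 2, 10, 16, 3, 7                           # invented sizes
hidden_dim = 64
mlp = nn.Linear(C * H * W, hidden_dim)                    # mlp_dim == C * H * W

x = torch.rand(B * T, C, H, W)                            # stand-in backbone output per patch
x = rearrange(x, "(b t) c h w -> b t (c h w)", b=B, t=T)  # (B, T, C * H * W)
x = mlp(x)                                                # (B, T, hidden_dim)
```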