diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-05-09 18:50:55 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-05-09 18:50:55 +0200 |
commit | a2a3133ed5da283888efbdb9924d0e3733c274c8 (patch) | |
tree | f6b49a227b08ff2e1a1c5809a576de6a2061ccf4 | |
parent | 548f52b35062e258622ea638ed1b132d6759a07a (diff) |
tranformer layer done
-rw-r--r-- | notebooks/00-scratch-pad.ipynb | 246 | ||||
-rw-r--r-- | text_recognizer/networks/loss/label_smoothing_loss.py | 42 | ||||
-rw-r--r-- | text_recognizer/networks/loss/loss.py | 39 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/layers.py | 42 |
4 files changed, 252 insertions, 117 deletions
diff --git a/notebooks/00-scratch-pad.ipynb b/notebooks/00-scratch-pad.ipynb index b6ec2c8..0a5e2f3 100644 --- a/notebooks/00-scratch-pad.ipynb +++ b/notebooks/00-scratch-pad.ipynb @@ -57,6 +57,181 @@ }, { "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from text_recognizer.networks.encoders.efficientnet import EfficientNet" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "en = EfficientNet()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "==========================================================================================\n", + "Layer (type:depth-idx) Output Shape Param #\n", + "==========================================================================================\n", + "├─Sequential: 1-1 [-1, 256, 18, 20] --\n", + "| └─ConvNorm: 2-1 [-1, 32, 288, 320] --\n", + "| | └─Sequential: 3-1 [-1, 32, 288, 320] 352\n", + "| └─InvertedResidulaBlock: 2-2 [-1, 16, 288, 320] --\n", + "| | └─Sequential: 3-2 [-1, 16, 288, 320] 1,448\n", + "| └─InvertedResidulaBlock: 2-3 [-1, 24, 144, 160] --\n", + "| | └─ConvNorm: 3-3 [-1, 96, 288, 320] 14,016\n", + "| | └─Sequential: 3-4 [-1, 24, 144, 160] 4,276\n", + "| └─InvertedResidulaBlock: 2-4 [-1, 24, 144, 160] --\n", + "| | └─ConvNorm: 3-5 [-1, 144, 144, 160] 31,392\n", + "| | └─Sequential: 3-6 [-1, 24, 144, 160] 6,966\n", + "| └─InvertedResidulaBlock: 2-5 [-1, 40, 72, 80] --\n", + "| | └─ConvNorm: 3-7 [-1, 144, 144, 160] 31,392\n", + "| | └─Sequential: 3-8 [-1, 40, 72, 80] 11,606\n", + "| └─InvertedResidulaBlock: 2-6 [-1, 40, 72, 80] --\n", + "| | └─ConvNorm: 3-9 [-1, 240, 72, 80] 86,880\n", + "| | └─Sequential: 3-10 [-1, 40, 72, 80] 21,210\n", + "| └─InvertedResidulaBlock: 2-7 [-1, 80, 36, 40] --\n", + "| | └─ConvNorm: 3-11 [-1, 240, 72, 80] 86,880\n", + "| | └─Sequential: 3-12 [-1, 80, 36, 40] 27,050\n", + "| └─InvertedResidulaBlock: 2-8 [-1, 80, 36, 40] --\n", + "| | └─ConvNorm: 3-13 [-1, 480, 36, 40] 346,560\n", + "| | └─Sequential: 3-14 [-1, 80, 36, 40] 63,540\n", + "| └─InvertedResidulaBlock: 2-9 [-1, 80, 36, 40] --\n", + "| | └─ConvNorm: 3-15 [-1, 480, 36, 40] 346,560\n", + "| | └─Sequential: 3-16 [-1, 80, 36, 40] 63,540\n", + "| └─InvertedResidulaBlock: 2-10 [-1, 112, 36, 40] --\n", + "| | └─ConvNorm: 3-17 [-1, 480, 36, 40] 346,560\n", + "| | └─Sequential: 3-18 [-1, 112, 36, 40] 86,644\n", + "| └─InvertedResidulaBlock: 2-11 [-1, 112, 36, 40] --\n", + "| | └─ConvNorm: 3-19 [-1, 672, 36, 40] 678,720\n", + "| | └─Sequential: 3-20 [-1, 112, 36, 40] 131,964\n", + "| └─InvertedResidulaBlock: 2-12 [-1, 112, 36, 40] --\n", + "| | └─ConvNorm: 3-21 [-1, 672, 36, 40] 678,720\n", + "| | └─Sequential: 3-22 [-1, 112, 36, 40] 131,964\n", + "| └─InvertedResidulaBlock: 2-13 [-1, 192, 18, 20] --\n", + "| | └─ConvNorm: 3-23 [-1, 672, 36, 40] 678,720\n", + "| | └─Sequential: 3-24 [-1, 192, 18, 20] 185,884\n", + "| └─InvertedResidulaBlock: 2-14 [-1, 192, 18, 20] --\n", + "| | └─ConvNorm: 3-25 [-1, 1152, 18, 20] 1,992,960\n", + "| | └─Sequential: 3-26 [-1, 192, 18, 20] 364,464\n", + "| └─InvertedResidulaBlock: 2-15 [-1, 192, 18, 20] --\n", + "| | └─ConvNorm: 3-27 [-1, 1152, 18, 20] 1,992,960\n", + "| | └─Sequential: 3-28 [-1, 192, 18, 20] 364,464\n", + "| └─InvertedResidulaBlock: 2-16 [-1, 192, 18, 20] --\n", + "| | └─ConvNorm: 3-29 [-1, 1152, 18, 20] 1,992,960\n", + "| | └─Sequential: 3-30 [-1, 192, 18, 20] 364,464\n", + "| └─InvertedResidulaBlock: 2-17 [-1, 320, 18, 20] --\n", + "| | └─ConvNorm: 3-31 [-1, 1152, 18, 20] 1,992,960\n", + "| | └─Sequential: 3-32 [-1, 320, 18, 20] 493,744\n", + "| └─ConvNorm: 2-18 [-1, 256, 18, 20] --\n", + "| | └─Sequential: 3-33 [-1, 256, 18, 20] 82,432\n", + "==========================================================================================\n", + "Total params: 13,704,252\n", + "Trainable params: 13,704,252\n", + "Non-trainable params: 0\n", + "Total mult-adds (G): 1.23\n", + "==========================================================================================\n", + "Input size (MB): 1.41\n", + "Forward/backward pass size (MB): 111.45\n", + "Params size (MB): 52.28\n", + "Estimated Total Size (MB): 165.13\n", + "==========================================================================================\n" + ] + }, + { + "data": { + "text/plain": [ + "==========================================================================================\n", + "Layer (type:depth-idx) Output Shape Param #\n", + "==========================================================================================\n", + "├─Sequential: 1-1 [-1, 256, 18, 20] --\n", + "| └─ConvNorm: 2-1 [-1, 32, 288, 320] --\n", + "| | └─Sequential: 3-1 [-1, 32, 288, 320] 352\n", + "| └─InvertedResidulaBlock: 2-2 [-1, 16, 288, 320] --\n", + "| | └─Sequential: 3-2 [-1, 16, 288, 320] 1,448\n", + "| └─InvertedResidulaBlock: 2-3 [-1, 24, 144, 160] --\n", + "| | └─ConvNorm: 3-3 [-1, 96, 288, 320] 14,016\n", + "| | └─Sequential: 3-4 [-1, 24, 144, 160] 4,276\n", + "| └─InvertedResidulaBlock: 2-4 [-1, 24, 144, 160] --\n", + "| | └─ConvNorm: 3-5 [-1, 144, 144, 160] 31,392\n", + "| | └─Sequential: 3-6 [-1, 24, 144, 160] 6,966\n", + "| └─InvertedResidulaBlock: 2-5 [-1, 40, 72, 80] --\n", + "| | └─ConvNorm: 3-7 [-1, 144, 144, 160] 31,392\n", + "| | └─Sequential: 3-8 [-1, 40, 72, 80] 11,606\n", + "| └─InvertedResidulaBlock: 2-6 [-1, 40, 72, 80] --\n", + "| | └─ConvNorm: 3-9 [-1, 240, 72, 80] 86,880\n", + "| | └─Sequential: 3-10 [-1, 40, 72, 80] 21,210\n", + "| └─InvertedResidulaBlock: 2-7 [-1, 80, 36, 40] --\n", + "| | └─ConvNorm: 3-11 [-1, 240, 72, 80] 86,880\n", + "| | └─Sequential: 3-12 [-1, 80, 36, 40] 27,050\n", + "| └─InvertedResidulaBlock: 2-8 [-1, 80, 36, 40] --\n", + "| | └─ConvNorm: 3-13 [-1, 480, 36, 40] 346,560\n", + "| | └─Sequential: 3-14 [-1, 80, 36, 40] 63,540\n", + "| └─InvertedResidulaBlock: 2-9 [-1, 80, 36, 40] --\n", + "| | └─ConvNorm: 3-15 [-1, 480, 36, 40] 346,560\n", + "| | └─Sequential: 3-16 [-1, 80, 36, 40] 63,540\n", + "| └─InvertedResidulaBlock: 2-10 [-1, 112, 36, 40] --\n", + "| | └─ConvNorm: 3-17 [-1, 480, 36, 40] 346,560\n", + "| | └─Sequential: 3-18 [-1, 112, 36, 40] 86,644\n", + "| └─InvertedResidulaBlock: 2-11 [-1, 112, 36, 40] --\n", + "| | └─ConvNorm: 3-19 [-1, 672, 36, 40] 678,720\n", + "| | └─Sequential: 3-20 [-1, 112, 36, 40] 131,964\n", + "| └─InvertedResidulaBlock: 2-12 [-1, 112, 36, 40] --\n", + "| | └─ConvNorm: 3-21 [-1, 672, 36, 40] 678,720\n", + "| | └─Sequential: 3-22 [-1, 112, 36, 40] 131,964\n", + "| └─InvertedResidulaBlock: 2-13 [-1, 192, 18, 20] --\n", + "| | └─ConvNorm: 3-23 [-1, 672, 36, 40] 678,720\n", + "| | └─Sequential: 3-24 [-1, 192, 18, 20] 185,884\n", + "| └─InvertedResidulaBlock: 2-14 [-1, 192, 18, 20] --\n", + "| | └─ConvNorm: 3-25 [-1, 1152, 18, 20] 1,992,960\n", + "| | └─Sequential: 3-26 [-1, 192, 18, 20] 364,464\n", + "| └─InvertedResidulaBlock: 2-15 [-1, 192, 18, 20] --\n", + "| | └─ConvNorm: 3-27 [-1, 1152, 18, 20] 1,992,960\n", + "| | └─Sequential: 3-28 [-1, 192, 18, 20] 364,464\n", + "| └─InvertedResidulaBlock: 2-16 [-1, 192, 18, 20] --\n", + "| | └─ConvNorm: 3-29 [-1, 1152, 18, 20] 1,992,960\n", + "| | └─Sequential: 3-30 [-1, 192, 18, 20] 364,464\n", + "| └─InvertedResidulaBlock: 2-17 [-1, 320, 18, 20] --\n", + "| | └─ConvNorm: 3-31 [-1, 1152, 18, 20] 1,992,960\n", + "| | └─Sequential: 3-32 [-1, 320, 18, 20] 493,744\n", + "| └─ConvNorm: 2-18 [-1, 256, 18, 20] --\n", + "| | └─Sequential: 3-33 [-1, 256, 18, 20] 82,432\n", + "==========================================================================================\n", + "Total params: 13,704,252\n", + "Trainable params: 13,704,252\n", + "Non-trainable params: 0\n", + "Total mult-adds (G): 1.23\n", + "==========================================================================================\n", + "Input size (MB): 1.41\n", + "Forward/backward pass size (MB): 111.45\n", + "Params size (MB): 52.28\n", + "Estimated Total Size (MB): 165.13\n", + "==========================================================================================" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "summary(en, (1, 576, 640))" + ] + }, + { + "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ @@ -409,77 +584,6 @@ }, { "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "8" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "len(list(filter(lambda x: x == \"a\", (\"a\", \"c\") * 8)))" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "ModuleList(\n", - " (0): ModuleList(\n", - " (0): Linear(in_features=10, out_features=10, bias=True)\n", - " )\n", - " (1): ModuleList(\n", - " (0): Linear(in_features=10, out_features=10, bias=True)\n", - " )\n", - " (2): ModuleList(\n", - " (0): Linear(in_features=10, out_features=10, bias=True)\n", - " )\n", - " (3): ModuleList(\n", - " (0): Linear(in_features=10, out_features=10, bias=True)\n", - " )\n", - " (4): ModuleList(\n", - " (0): Linear(in_features=10, out_features=10, bias=True)\n", - " )\n", - " (5): ModuleList(\n", - " (0): Linear(in_features=10, out_features=10, bias=True)\n", - " )\n", - " (6): ModuleList(\n", - " (0): Linear(in_features=10, out_features=10, bias=True)\n", - " )\n", - " (7): ModuleList(\n", - " (0): Linear(in_features=10, out_features=10, bias=True)\n", - " )\n", - " (8): ModuleList(\n", - " (0): Linear(in_features=10, out_features=10, bias=True)\n", - " )\n", - " (9): ModuleList(\n", - " (0): Linear(in_features=10, out_features=10, bias=True)\n", - " )\n", - ")" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "nn.ModuleList([nn.ModuleList([nn.Linear(10, 10)]) for _ in range(10)])" - ] - }, - { - "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], diff --git a/text_recognizer/networks/loss/label_smoothing_loss.py b/text_recognizer/networks/loss/label_smoothing_loss.py new file mode 100644 index 0000000..40a7609 --- /dev/null +++ b/text_recognizer/networks/loss/label_smoothing_loss.py @@ -0,0 +1,42 @@ +"""Implementations of custom loss functions.""" +import torch +from torch import nn +from torch import Tensor +import torch.nn.functional as F + + +class LabelSmoothingLoss(nn.Module): + """Label smoothing cross entropy loss.""" + + def __init__( + self, label_smoothing: float, vocab_size: int, ignore_index: int = -100 + ) -> None: + assert 0.0 < label_smoothing <= 1.0 + self.ignore_index = ignore_index + super().__init__() + + smoothing_value = label_smoothing / (vocab_size - 2) + one_hot = torch.full((vocab_size,), smoothing_value) + one_hot[self.ignore_index] = 0 + self.register_buffer("one_hot", one_hot.unsqueeze(0)) + + self.confidence = 1.0 - label_smoothing + + def forward(self, output: Tensor, targets: Tensor) -> Tensor: + """Computes the loss. + + Args: + output (Tensor): Predictions from the network. + targets (Tensor): Ground truth. + + Shapes: + outpus: Batch size x num classes + targets: Batch size + + Returns: + Tensor: Label smoothing loss. + """ + model_prob = self.one_hot.repeat(targets.size(0), 1) + model_prob.scatter_(1, targets.unsqueeze(1), self.confidence) + model_prob.masked_fill_((targets == self.ignore_index).unsqueeze(1), 0) + return F.kl_div(output, model_prob, reduction="sum") diff --git a/text_recognizer/networks/loss/loss.py b/text_recognizer/networks/loss/loss.py deleted file mode 100644 index d12dc9c..0000000 --- a/text_recognizer/networks/loss/loss.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Implementations of custom loss functions.""" -import torch -from torch import nn -from torch import Tensor - -__all__ = ["LabelSmoothingCrossEntropy"] - - -class LabelSmoothingCrossEntropy(nn.Module): - """Label smoothing loss function.""" - - def __init__( - self, - classes: int, - smoothing: float = 0.0, - ignore_index: int = None, - dim: int = -1, - ) -> None: - super().__init__() - self.confidence = 1.0 - smoothing - self.smoothing = smoothing - self.ignore_index = ignore_index - self.cls = classes - self.dim = dim - - def forward(self, pred: Tensor, target: Tensor) -> Tensor: - """Calculates the loss.""" - pred = pred.log_softmax(dim=self.dim) - with torch.no_grad(): - # true_dist = pred.data.clone() - true_dist = torch.zeros_like(pred) - true_dist.fill_(self.smoothing / (self.cls - 1)) - true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) - if self.ignore_index is not None: - true_dist[:, self.ignore_index] = 0 - mask = torch.nonzero(target == self.ignore_index, as_tuple=False) - if mask.dim() > 0: - true_dist.index_fill_(0, mask.squeeze(), 0.0) - return torch.mean(torch.sum(-true_dist * pred, dim=self.dim)) diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py index 1c951ae..a2fdb1a 100644 --- a/text_recognizer/networks/transformer/layers.py +++ b/text_recognizer/networks/transformer/layers.py @@ -10,6 +10,7 @@ from torch import nn, Tensor from .attention import Attention from .mlp import FeedForward from .residual import Residual +from .rotary_embedding import RotaryEmbedding class AttentionLayers(nn.Module): @@ -24,17 +25,23 @@ class AttentionLayers(nn.Module): norm_fn: Type[nn.Module] = nn.LayerNorm, ff_fn: Type[nn.Module] = FeedForward, residual_fn: Type[nn.Module] = Residual, + rotary_emb: Optional[Type[nn.Module]] = None, + rotary_emb_dim: Optional[int] = None, causal: bool = False, cross_attend: bool = False, + pre_norm: bool = True, ) -> None: super().__init__() attn_fn = partial(attn_fn, dim=dim, num_heads=num_heads, **attn_kwargs) norm_fn = partial(norm_fn, dim=dim) ff_fn = partial(ff_fn, dim=dim, **ff_kwargs) - layer_types = self._get_layer_types(cross_attend) * depth + self.layer_types = self._get_layer_types(cross_attend) * depth self.layers = self._build_network( - layer_types, causal, attn_fn, norm_fn, ff_fn, residual_fn + causal, attn_fn, norm_fn, ff_fn, residual_fn ) + rotary_emb_dim = max(rotary_emb_dim, 32) if rotary_emb_dim is not None else None + self.rotary_emb = RotaryEmbedding(rotary_emb_dim) if rotary_emb else None + self.pre_norm = pre_norm @staticmethod def _get_layer_types(cross_attend: bool) -> Tuple: @@ -43,18 +50,17 @@ class AttentionLayers(nn.Module): return "a", "c", "f" return "a", "f" - @staticmethod def _build_network( - layer_types: Tuple, + self, causal: bool, attn_fn: partial, norm_fn: partial, ff_fn: partial, residual_fn: Type[nn.Module], ) -> nn.ModuleList: - """Configures transformer layers.""" + """Configures transformer network.""" layers = nn.ModuleList([]) - for layer_type in layer_types: + for layer_type in self.layer_types: if layer_type == "a": layer = attn_fn(causal=causal) elif layer_type == "c": @@ -74,4 +80,26 @@ class AttentionLayers(nn.Module): mask: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, ) -> Tensor: - pass + rotary_pos_emb = self.rotary_emb(x) if self.rotary_emb is not None else None + + for i, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): + is_last = i == len(self.layers) - 1 + + residual = x + + if self.pre_norm: + x = norm(x) + + if layer_type == "a": + out, _ = block(x=x, mask=mask, rotary_pos_emb=rotary_pos_emb) + elif layer_type == "c": + out, _ = block(x, context=context, mask=mask, context_mask=context_mask) + elif layer_type == "f": + out = block(x) + + x = residual_fn(out, residual) + + if not self.pre_norm and not is_last: + x = norm(x) + + return x |