author    Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-05-09 18:50:55 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-05-09 18:50:55 +0200
commit    a2a3133ed5da283888efbdb9924d0e3733c274c8 (patch)
tree      f6b49a227b08ff2e1a1c5809a576de6a2061ccf4
parent    548f52b35062e258622ea638ed1b132d6759a07a (diff)
transformer layer done
-rw-r--r--  notebooks/00-scratch-pad.ipynb                           246
-rw-r--r--  text_recognizer/networks/loss/label_smoothing_loss.py     42
-rw-r--r--  text_recognizer/networks/loss/loss.py                      39
-rw-r--r--  text_recognizer/networks/transformer/layers.py             42
4 files changed, 252 insertions(+), 117 deletions(-)
diff --git a/notebooks/00-scratch-pad.ipynb b/notebooks/00-scratch-pad.ipynb
index b6ec2c8..0a5e2f3 100644
--- a/notebooks/00-scratch-pad.ipynb
+++ b/notebooks/00-scratch-pad.ipynb
@@ -57,6 +57,181 @@
},
{
"cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.networks.encoders.efficientnet import EfficientNet"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "en = EfficientNet()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "==========================================================================================\n",
+ "Layer (type:depth-idx) Output Shape Param #\n",
+ "==========================================================================================\n",
+ "├─Sequential: 1-1 [-1, 256, 18, 20] --\n",
+ "| └─ConvNorm: 2-1 [-1, 32, 288, 320] --\n",
+ "| | └─Sequential: 3-1 [-1, 32, 288, 320] 352\n",
+ "| └─InvertedResidulaBlock: 2-2 [-1, 16, 288, 320] --\n",
+ "| | └─Sequential: 3-2 [-1, 16, 288, 320] 1,448\n",
+ "| └─InvertedResidulaBlock: 2-3 [-1, 24, 144, 160] --\n",
+ "| | └─ConvNorm: 3-3 [-1, 96, 288, 320] 14,016\n",
+ "| | └─Sequential: 3-4 [-1, 24, 144, 160] 4,276\n",
+ "| └─InvertedResidulaBlock: 2-4 [-1, 24, 144, 160] --\n",
+ "| | └─ConvNorm: 3-5 [-1, 144, 144, 160] 31,392\n",
+ "| | └─Sequential: 3-6 [-1, 24, 144, 160] 6,966\n",
+ "| └─InvertedResidulaBlock: 2-5 [-1, 40, 72, 80] --\n",
+ "| | └─ConvNorm: 3-7 [-1, 144, 144, 160] 31,392\n",
+ "| | └─Sequential: 3-8 [-1, 40, 72, 80] 11,606\n",
+ "| └─InvertedResidulaBlock: 2-6 [-1, 40, 72, 80] --\n",
+ "| | └─ConvNorm: 3-9 [-1, 240, 72, 80] 86,880\n",
+ "| | └─Sequential: 3-10 [-1, 40, 72, 80] 21,210\n",
+ "| └─InvertedResidulaBlock: 2-7 [-1, 80, 36, 40] --\n",
+ "| | └─ConvNorm: 3-11 [-1, 240, 72, 80] 86,880\n",
+ "| | └─Sequential: 3-12 [-1, 80, 36, 40] 27,050\n",
+ "| └─InvertedResidulaBlock: 2-8 [-1, 80, 36, 40] --\n",
+ "| | └─ConvNorm: 3-13 [-1, 480, 36, 40] 346,560\n",
+ "| | └─Sequential: 3-14 [-1, 80, 36, 40] 63,540\n",
+ "| └─InvertedResidulaBlock: 2-9 [-1, 80, 36, 40] --\n",
+ "| | └─ConvNorm: 3-15 [-1, 480, 36, 40] 346,560\n",
+ "| | └─Sequential: 3-16 [-1, 80, 36, 40] 63,540\n",
+ "| └─InvertedResidulaBlock: 2-10 [-1, 112, 36, 40] --\n",
+ "| | └─ConvNorm: 3-17 [-1, 480, 36, 40] 346,560\n",
+ "| | └─Sequential: 3-18 [-1, 112, 36, 40] 86,644\n",
+ "| └─InvertedResidulaBlock: 2-11 [-1, 112, 36, 40] --\n",
+ "| | └─ConvNorm: 3-19 [-1, 672, 36, 40] 678,720\n",
+ "| | └─Sequential: 3-20 [-1, 112, 36, 40] 131,964\n",
+ "| └─InvertedResidulaBlock: 2-12 [-1, 112, 36, 40] --\n",
+ "| | └─ConvNorm: 3-21 [-1, 672, 36, 40] 678,720\n",
+ "| | └─Sequential: 3-22 [-1, 112, 36, 40] 131,964\n",
+ "| └─InvertedResidulaBlock: 2-13 [-1, 192, 18, 20] --\n",
+ "| | └─ConvNorm: 3-23 [-1, 672, 36, 40] 678,720\n",
+ "| | └─Sequential: 3-24 [-1, 192, 18, 20] 185,884\n",
+ "| └─InvertedResidulaBlock: 2-14 [-1, 192, 18, 20] --\n",
+ "| | └─ConvNorm: 3-25 [-1, 1152, 18, 20] 1,992,960\n",
+ "| | └─Sequential: 3-26 [-1, 192, 18, 20] 364,464\n",
+ "| └─InvertedResidulaBlock: 2-15 [-1, 192, 18, 20] --\n",
+ "| | └─ConvNorm: 3-27 [-1, 1152, 18, 20] 1,992,960\n",
+ "| | └─Sequential: 3-28 [-1, 192, 18, 20] 364,464\n",
+ "| └─InvertedResidulaBlock: 2-16 [-1, 192, 18, 20] --\n",
+ "| | └─ConvNorm: 3-29 [-1, 1152, 18, 20] 1,992,960\n",
+ "| | └─Sequential: 3-30 [-1, 192, 18, 20] 364,464\n",
+ "| └─InvertedResidulaBlock: 2-17 [-1, 320, 18, 20] --\n",
+ "| | └─ConvNorm: 3-31 [-1, 1152, 18, 20] 1,992,960\n",
+ "| | └─Sequential: 3-32 [-1, 320, 18, 20] 493,744\n",
+ "| └─ConvNorm: 2-18 [-1, 256, 18, 20] --\n",
+ "| | └─Sequential: 3-33 [-1, 256, 18, 20] 82,432\n",
+ "==========================================================================================\n",
+ "Total params: 13,704,252\n",
+ "Trainable params: 13,704,252\n",
+ "Non-trainable params: 0\n",
+ "Total mult-adds (G): 1.23\n",
+ "==========================================================================================\n",
+ "Input size (MB): 1.41\n",
+ "Forward/backward pass size (MB): 111.45\n",
+ "Params size (MB): 52.28\n",
+ "Estimated Total Size (MB): 165.13\n",
+ "==========================================================================================\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "==========================================================================================\n",
+ "Layer (type:depth-idx) Output Shape Param #\n",
+ "==========================================================================================\n",
+ "├─Sequential: 1-1 [-1, 256, 18, 20] --\n",
+ "| └─ConvNorm: 2-1 [-1, 32, 288, 320] --\n",
+ "| | └─Sequential: 3-1 [-1, 32, 288, 320] 352\n",
+ "| └─InvertedResidulaBlock: 2-2 [-1, 16, 288, 320] --\n",
+ "| | └─Sequential: 3-2 [-1, 16, 288, 320] 1,448\n",
+ "| └─InvertedResidulaBlock: 2-3 [-1, 24, 144, 160] --\n",
+ "| | └─ConvNorm: 3-3 [-1, 96, 288, 320] 14,016\n",
+ "| | └─Sequential: 3-4 [-1, 24, 144, 160] 4,276\n",
+ "| └─InvertedResidulaBlock: 2-4 [-1, 24, 144, 160] --\n",
+ "| | └─ConvNorm: 3-5 [-1, 144, 144, 160] 31,392\n",
+ "| | └─Sequential: 3-6 [-1, 24, 144, 160] 6,966\n",
+ "| └─InvertedResidulaBlock: 2-5 [-1, 40, 72, 80] --\n",
+ "| | └─ConvNorm: 3-7 [-1, 144, 144, 160] 31,392\n",
+ "| | └─Sequential: 3-8 [-1, 40, 72, 80] 11,606\n",
+ "| └─InvertedResidulaBlock: 2-6 [-1, 40, 72, 80] --\n",
+ "| | └─ConvNorm: 3-9 [-1, 240, 72, 80] 86,880\n",
+ "| | └─Sequential: 3-10 [-1, 40, 72, 80] 21,210\n",
+ "| └─InvertedResidulaBlock: 2-7 [-1, 80, 36, 40] --\n",
+ "| | └─ConvNorm: 3-11 [-1, 240, 72, 80] 86,880\n",
+ "| | └─Sequential: 3-12 [-1, 80, 36, 40] 27,050\n",
+ "| └─InvertedResidulaBlock: 2-8 [-1, 80, 36, 40] --\n",
+ "| | └─ConvNorm: 3-13 [-1, 480, 36, 40] 346,560\n",
+ "| | └─Sequential: 3-14 [-1, 80, 36, 40] 63,540\n",
+ "| └─InvertedResidulaBlock: 2-9 [-1, 80, 36, 40] --\n",
+ "| | └─ConvNorm: 3-15 [-1, 480, 36, 40] 346,560\n",
+ "| | └─Sequential: 3-16 [-1, 80, 36, 40] 63,540\n",
+ "| └─InvertedResidulaBlock: 2-10 [-1, 112, 36, 40] --\n",
+ "| | └─ConvNorm: 3-17 [-1, 480, 36, 40] 346,560\n",
+ "| | └─Sequential: 3-18 [-1, 112, 36, 40] 86,644\n",
+ "| └─InvertedResidulaBlock: 2-11 [-1, 112, 36, 40] --\n",
+ "| | └─ConvNorm: 3-19 [-1, 672, 36, 40] 678,720\n",
+ "| | └─Sequential: 3-20 [-1, 112, 36, 40] 131,964\n",
+ "| └─InvertedResidulaBlock: 2-12 [-1, 112, 36, 40] --\n",
+ "| | └─ConvNorm: 3-21 [-1, 672, 36, 40] 678,720\n",
+ "| | └─Sequential: 3-22 [-1, 112, 36, 40] 131,964\n",
+ "| └─InvertedResidulaBlock: 2-13 [-1, 192, 18, 20] --\n",
+ "| | └─ConvNorm: 3-23 [-1, 672, 36, 40] 678,720\n",
+ "| | └─Sequential: 3-24 [-1, 192, 18, 20] 185,884\n",
+ "| └─InvertedResidulaBlock: 2-14 [-1, 192, 18, 20] --\n",
+ "| | └─ConvNorm: 3-25 [-1, 1152, 18, 20] 1,992,960\n",
+ "| | └─Sequential: 3-26 [-1, 192, 18, 20] 364,464\n",
+ "| └─InvertedResidulaBlock: 2-15 [-1, 192, 18, 20] --\n",
+ "| | └─ConvNorm: 3-27 [-1, 1152, 18, 20] 1,992,960\n",
+ "| | └─Sequential: 3-28 [-1, 192, 18, 20] 364,464\n",
+ "| └─InvertedResidulaBlock: 2-16 [-1, 192, 18, 20] --\n",
+ "| | └─ConvNorm: 3-29 [-1, 1152, 18, 20] 1,992,960\n",
+ "| | └─Sequential: 3-30 [-1, 192, 18, 20] 364,464\n",
+ "| └─InvertedResidulaBlock: 2-17 [-1, 320, 18, 20] --\n",
+ "| | └─ConvNorm: 3-31 [-1, 1152, 18, 20] 1,992,960\n",
+ "| | └─Sequential: 3-32 [-1, 320, 18, 20] 493,744\n",
+ "| └─ConvNorm: 2-18 [-1, 256, 18, 20] --\n",
+ "| | └─Sequential: 3-33 [-1, 256, 18, 20] 82,432\n",
+ "==========================================================================================\n",
+ "Total params: 13,704,252\n",
+ "Trainable params: 13,704,252\n",
+ "Non-trainable params: 0\n",
+ "Total mult-adds (G): 1.23\n",
+ "==========================================================================================\n",
+ "Input size (MB): 1.41\n",
+ "Forward/backward pass size (MB): 111.45\n",
+ "Params size (MB): 52.28\n",
+ "Estimated Total Size (MB): 165.13\n",
+ "=========================================================================================="
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "summary(en, (1, 576, 640))"
+ ]
+ },
+ {
+ "cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
@@ -409,77 +584,6 @@
},
{
"cell_type": "code",
- "execution_count": 12,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "8"
- ]
- },
- "execution_count": 12,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "len(list(filter(lambda x: x == \"a\", (\"a\", \"c\") * 8)))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "ModuleList(\n",
- " (0): ModuleList(\n",
- " (0): Linear(in_features=10, out_features=10, bias=True)\n",
- " )\n",
- " (1): ModuleList(\n",
- " (0): Linear(in_features=10, out_features=10, bias=True)\n",
- " )\n",
- " (2): ModuleList(\n",
- " (0): Linear(in_features=10, out_features=10, bias=True)\n",
- " )\n",
- " (3): ModuleList(\n",
- " (0): Linear(in_features=10, out_features=10, bias=True)\n",
- " )\n",
- " (4): ModuleList(\n",
- " (0): Linear(in_features=10, out_features=10, bias=True)\n",
- " )\n",
- " (5): ModuleList(\n",
- " (0): Linear(in_features=10, out_features=10, bias=True)\n",
- " )\n",
- " (6): ModuleList(\n",
- " (0): Linear(in_features=10, out_features=10, bias=True)\n",
- " )\n",
- " (7): ModuleList(\n",
- " (0): Linear(in_features=10, out_features=10, bias=True)\n",
- " )\n",
- " (8): ModuleList(\n",
- " (0): Linear(in_features=10, out_features=10, bias=True)\n",
- " )\n",
- " (9): ModuleList(\n",
- " (0): Linear(in_features=10, out_features=10, bias=True)\n",
- " )\n",
- ")"
- ]
- },
- "execution_count": 2,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "nn.ModuleList([nn.ModuleList([nn.Linear(10, 10)]) for _ in range(10)])"
- ]
- },
- {
- "cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
diff --git a/text_recognizer/networks/loss/label_smoothing_loss.py b/text_recognizer/networks/loss/label_smoothing_loss.py
new file mode 100644
index 0000000..40a7609
--- /dev/null
+++ b/text_recognizer/networks/loss/label_smoothing_loss.py
@@ -0,0 +1,42 @@
+"""Implementations of custom loss functions."""
+import torch
+from torch import nn
+from torch import Tensor
+import torch.nn.functional as F
+
+
+class LabelSmoothingLoss(nn.Module):
+ """Label smoothing cross entropy loss."""
+
+ def __init__(
+ self, label_smoothing: float, vocab_size: int, ignore_index: int = -100
+ ) -> None:
+ super().__init__()
+ assert 0.0 < label_smoothing <= 1.0
+ self.ignore_index = ignore_index
+
+ smoothing_value = label_smoothing / (vocab_size - 2)
+ one_hot = torch.full((vocab_size,), smoothing_value)
+ one_hot[self.ignore_index] = 0
+ self.register_buffer("one_hot", one_hot.unsqueeze(0))
+
+ self.confidence = 1.0 - label_smoothing
+
+ def forward(self, output: Tensor, targets: Tensor) -> Tensor:
+ """Computes the loss.
+
+ Args:
+ output (Tensor): Predictions from the network.
+ targets (Tensor): Ground truth.
+
+ Shapes:
+ output: Batch size x num classes
+ targets: Batch size
+
+ Returns:
+ Tensor: Label smoothing loss.
+ """
+ model_prob = self.one_hot.repeat(targets.size(0), 1)
+ model_prob.scatter_(1, targets.unsqueeze(1), self.confidence)
+ model_prob.masked_fill_((targets == self.ignore_index).unsqueeze(1), 0)
+ return F.kl_div(output, model_prob, reduction="sum")
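
Because the new loss ends in F.kl_div, its output argument must already be log-probabilities, not raw logits. A minimal usage sketch under that assumption (tensor values are illustrative; note that the default ignore_index of -100 indexes one_hot from the end, so an in-vocabulary padding index is passed explicitly here):

import torch

from text_recognizer.networks.loss.label_smoothing_loss import LabelSmoothingLoss

vocab_size = 58
criterion = LabelSmoothingLoss(label_smoothing=0.1, vocab_size=vocab_size, ignore_index=3)

logits = torch.randn(4, vocab_size)     # raw network outputs
log_probs = logits.log_softmax(dim=-1)  # F.kl_div expects log-probabilities
targets = torch.tensor([5, 3, 17, 42])  # class indices; index 3 is masked out
loss = criterion(log_probs, targets)
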
diff --git a/text_recognizer/networks/loss/loss.py b/text_recognizer/networks/loss/loss.py
deleted file mode 100644
index d12dc9c..0000000
--- a/text_recognizer/networks/loss/loss.py
+++ /dev/null
@@ -1,39 +0,0 @@
-"""Implementations of custom loss functions."""
-import torch
-from torch import nn
-from torch import Tensor
-
-__all__ = ["LabelSmoothingCrossEntropy"]
-
-
-class LabelSmoothingCrossEntropy(nn.Module):
- """Label smoothing loss function."""
-
- def __init__(
- self,
- classes: int,
- smoothing: float = 0.0,
- ignore_index: int = None,
- dim: int = -1,
- ) -> None:
- super().__init__()
- self.confidence = 1.0 - smoothing
- self.smoothing = smoothing
- self.ignore_index = ignore_index
- self.cls = classes
- self.dim = dim
-
- def forward(self, pred: Tensor, target: Tensor) -> Tensor:
- """Calculates the loss."""
- pred = pred.log_softmax(dim=self.dim)
- with torch.no_grad():
- # true_dist = pred.data.clone()
- true_dist = torch.zeros_like(pred)
- true_dist.fill_(self.smoothing / (self.cls - 1))
- true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
- if self.ignore_index is not None:
- true_dist[:, self.ignore_index] = 0
- mask = torch.nonzero(target == self.ignore_index, as_tuple=False)
- if mask.dim() > 0:
- true_dist.index_fill_(0, mask.squeeze(), 0.0)
- return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
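
For comparison, the deleted class computed the same label-smoothed objective as an explicit cross entropy rather than a KL divergence. A condensed restatement of its forward pass, using the same names as the removed code (ignore_index handling omitted, values illustrative):

# Build a target distribution with (1 - smoothing) on the gold class and the
# remaining mass spread uniformly over the other classes, then take the mean
# cross entropy against the log-softmaxed predictions.
import torch

smoothing, num_classes = 0.1, 58
pred = torch.randn(4, num_classes).log_softmax(dim=-1)
target = torch.tensor([5, 0, 17, 42])

true_dist = torch.full_like(pred, smoothing / (num_classes - 1))
true_dist.scatter_(1, target.unsqueeze(1), 1.0 - smoothing)
loss = torch.mean(torch.sum(-true_dist * pred, dim=-1))
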
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index 1c951ae..a2fdb1a 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -10,6 +10,7 @@ from torch import nn, Tensor
from .attention import Attention
from .mlp import FeedForward
from .residual import Residual
+from .rotary_embedding import RotaryEmbedding
class AttentionLayers(nn.Module):
@@ -24,17 +25,23 @@ class AttentionLayers(nn.Module):
norm_fn: Type[nn.Module] = nn.LayerNorm,
ff_fn: Type[nn.Module] = FeedForward,
residual_fn: Type[nn.Module] = Residual,
+ rotary_emb: Optional[Type[nn.Module]] = None,
+ rotary_emb_dim: Optional[int] = None,
causal: bool = False,
cross_attend: bool = False,
+ pre_norm: bool = True,
) -> None:
super().__init__()
attn_fn = partial(attn_fn, dim=dim, num_heads=num_heads, **attn_kwargs)
norm_fn = partial(norm_fn, dim=dim)
ff_fn = partial(ff_fn, dim=dim, **ff_kwargs)
- layer_types = self._get_layer_types(cross_attend) * depth
+ self.layer_types = self._get_layer_types(cross_attend) * depth
self.layers = self._build_network(
- layer_types, causal, attn_fn, norm_fn, ff_fn, residual_fn
+ causal, attn_fn, norm_fn, ff_fn, residual_fn
)
+ rotary_emb_dim = max(rotary_emb_dim, 32) if rotary_emb_dim is not None else None
+ self.rotary_emb = RotaryEmbedding(rotary_emb_dim) if rotary_emb else None
+ self.pre_norm = pre_norm
@staticmethod
def _get_layer_types(cross_attend: bool) -> Tuple:
@@ -43,18 +50,17 @@ class AttentionLayers(nn.Module):
return "a", "c", "f"
return "a", "f"
- @staticmethod
def _build_network(
- layer_types: Tuple,
+ self,
causal: bool,
attn_fn: partial,
norm_fn: partial,
ff_fn: partial,
residual_fn: Type[nn.Module],
) -> nn.ModuleList:
- """Configures transformer layers."""
+ """Configures transformer network."""
layers = nn.ModuleList([])
- for layer_type in layer_types:
+ for layer_type in self.layer_types:
if layer_type == "a":
layer = attn_fn(causal=causal)
elif layer_type == "c":
@@ -74,4 +80,26 @@ class AttentionLayers(nn.Module):
mask: Optional[Tensor] = None,
context_mask: Optional[Tensor] = None,
) -> Tensor:
- pass
+ rotary_pos_emb = self.rotary_emb(x) if self.rotary_emb is not None else None
+
+ for i, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
+ is_last = i == len(self.layers) - 1
+
+ residual = x
+
+ if self.pre_norm:
+ x = norm(x)
+
+ if layer_type == "a":
+ out, _ = block(x=x, mask=mask, rotary_pos_emb=rotary_pos_emb)
+ elif layer_type == "c":
+ out, _ = block(x, context=context, mask=mask, context_mask=context_mask)
+ elif layer_type == "f":
+ out = block(x)
+
+ x = residual_fn(out, residual)
+
+ if not self.pre_norm and not is_last:
+ x = norm(x)
+
+ return x
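
A hedged instantiation sketch for the finished layer stack. The hunk shows dim, num_heads, attn_kwargs, and ff_kwargs being consumed through functools.partial, so they are assumed to be constructor parameters; depth and the exact keyword names are inferred rather than confirmed by the visible diff. Both rotary arguments are set together, since enabling rotary_emb without rotary_emb_dim would hand None to RotaryEmbedding:

# Inferred constructor signature; depth and keyword names are assumptions.
import torch

from text_recognizer.networks.transformer.layers import AttentionLayers

decoder = AttentionLayers(
    dim=256,
    depth=2,
    num_heads=8,
    attn_kwargs={},
    ff_kwargs={},
    rotary_emb=True,    # truthiness is all the new __init__ checks
    rotary_emb_dim=64,  # clamped to at least 32 by the new code
    causal=True,
    pre_norm=True,
)

x = torch.randn(1, 100, 256)
out = decoder(x)  # forward now walks layer_types instead of `pass`
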