From a2a3133ed5da283888efbdb9924d0e3733c274c8 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 9 May 2021 18:50:55 +0200 Subject: tranformer layer done --- .../networks/loss/label_smoothing_loss.py | 42 ++++++++++++++++++++++ text_recognizer/networks/loss/loss.py | 39 -------------------- text_recognizer/networks/transformer/layers.py | 42 ++++++++++++++++++---- 3 files changed, 77 insertions(+), 46 deletions(-) create mode 100644 text_recognizer/networks/loss/label_smoothing_loss.py delete mode 100644 text_recognizer/networks/loss/loss.py (limited to 'text_recognizer/networks') diff --git a/text_recognizer/networks/loss/label_smoothing_loss.py b/text_recognizer/networks/loss/label_smoothing_loss.py new file mode 100644 index 0000000..40a7609 --- /dev/null +++ b/text_recognizer/networks/loss/label_smoothing_loss.py @@ -0,0 +1,42 @@ +"""Implementations of custom loss functions.""" +import torch +from torch import nn +from torch import Tensor +import torch.nn.functional as F + + +class LabelSmoothingLoss(nn.Module): + """Label smoothing cross entropy loss.""" + + def __init__( + self, label_smoothing: float, vocab_size: int, ignore_index: int = -100 + ) -> None: + assert 0.0 < label_smoothing <= 1.0 + self.ignore_index = ignore_index + super().__init__() + + smoothing_value = label_smoothing / (vocab_size - 2) + one_hot = torch.full((vocab_size,), smoothing_value) + one_hot[self.ignore_index] = 0 + self.register_buffer("one_hot", one_hot.unsqueeze(0)) + + self.confidence = 1.0 - label_smoothing + + def forward(self, output: Tensor, targets: Tensor) -> Tensor: + """Computes the loss. + + Args: + output (Tensor): Predictions from the network. + targets (Tensor): Ground truth. + + Shapes: + outpus: Batch size x num classes + targets: Batch size + + Returns: + Tensor: Label smoothing loss. + """ + model_prob = self.one_hot.repeat(targets.size(0), 1) + model_prob.scatter_(1, targets.unsqueeze(1), self.confidence) + model_prob.masked_fill_((targets == self.ignore_index).unsqueeze(1), 0) + return F.kl_div(output, model_prob, reduction="sum") diff --git a/text_recognizer/networks/loss/loss.py b/text_recognizer/networks/loss/loss.py deleted file mode 100644 index d12dc9c..0000000 --- a/text_recognizer/networks/loss/loss.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Implementations of custom loss functions.""" -import torch -from torch import nn -from torch import Tensor - -__all__ = ["LabelSmoothingCrossEntropy"] - - -class LabelSmoothingCrossEntropy(nn.Module): - """Label smoothing loss function.""" - - def __init__( - self, - classes: int, - smoothing: float = 0.0, - ignore_index: int = None, - dim: int = -1, - ) -> None: - super().__init__() - self.confidence = 1.0 - smoothing - self.smoothing = smoothing - self.ignore_index = ignore_index - self.cls = classes - self.dim = dim - - def forward(self, pred: Tensor, target: Tensor) -> Tensor: - """Calculates the loss.""" - pred = pred.log_softmax(dim=self.dim) - with torch.no_grad(): - # true_dist = pred.data.clone() - true_dist = torch.zeros_like(pred) - true_dist.fill_(self.smoothing / (self.cls - 1)) - true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) - if self.ignore_index is not None: - true_dist[:, self.ignore_index] = 0 - mask = torch.nonzero(target == self.ignore_index, as_tuple=False) - if mask.dim() > 0: - true_dist.index_fill_(0, mask.squeeze(), 0.0) - return torch.mean(torch.sum(-true_dist * pred, dim=self.dim)) diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py index 1c951ae..a2fdb1a 100644 --- a/text_recognizer/networks/transformer/layers.py +++ b/text_recognizer/networks/transformer/layers.py @@ -10,6 +10,7 @@ from torch import nn, Tensor from .attention import Attention from .mlp import FeedForward from .residual import Residual +from .rotary_embedding import RotaryEmbedding class AttentionLayers(nn.Module): @@ -24,17 +25,23 @@ class AttentionLayers(nn.Module): norm_fn: Type[nn.Module] = nn.LayerNorm, ff_fn: Type[nn.Module] = FeedForward, residual_fn: Type[nn.Module] = Residual, + rotary_emb: Optional[Type[nn.Module]] = None, + rotary_emb_dim: Optional[int] = None, causal: bool = False, cross_attend: bool = False, + pre_norm: bool = True, ) -> None: super().__init__() attn_fn = partial(attn_fn, dim=dim, num_heads=num_heads, **attn_kwargs) norm_fn = partial(norm_fn, dim=dim) ff_fn = partial(ff_fn, dim=dim, **ff_kwargs) - layer_types = self._get_layer_types(cross_attend) * depth + self.layer_types = self._get_layer_types(cross_attend) * depth self.layers = self._build_network( - layer_types, causal, attn_fn, norm_fn, ff_fn, residual_fn + causal, attn_fn, norm_fn, ff_fn, residual_fn ) + rotary_emb_dim = max(rotary_emb_dim, 32) if rotary_emb_dim is not None else None + self.rotary_emb = RotaryEmbedding(rotary_emb_dim) if rotary_emb else None + self.pre_norm = pre_norm @staticmethod def _get_layer_types(cross_attend: bool) -> Tuple: @@ -43,18 +50,17 @@ class AttentionLayers(nn.Module): return "a", "c", "f" return "a", "f" - @staticmethod def _build_network( - layer_types: Tuple, + self, causal: bool, attn_fn: partial, norm_fn: partial, ff_fn: partial, residual_fn: Type[nn.Module], ) -> nn.ModuleList: - """Configures transformer layers.""" + """Configures transformer network.""" layers = nn.ModuleList([]) - for layer_type in layer_types: + for layer_type in self.layer_types: if layer_type == "a": layer = attn_fn(causal=causal) elif layer_type == "c": @@ -74,4 +80,26 @@ class AttentionLayers(nn.Module): mask: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, ) -> Tensor: - pass + rotary_pos_emb = self.rotary_emb(x) if self.rotary_emb is not None else None + + for i, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): + is_last = i == len(self.layers) - 1 + + residual = x + + if self.pre_norm: + x = norm(x) + + if layer_type == "a": + out, _ = block(x=x, mask=mask, rotary_pos_emb=rotary_pos_emb) + elif layer_type == "c": + out, _ = block(x, context=context, mask=mask, context_mask=context_mask) + elif layer_type == "f": + out = block(x) + + x = residual_fn(out, residual) + + if not self.pre_norm and not is_last: + x = norm(x) + + return x -- cgit v1.2.3-70-g09d2