From b3fbfd72a8f647161685b28d20b4b61519d8a643 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 15 Apr 2024 21:49:51 +0200 Subject: Update transformer model --- text_recognizer/network/transformer/attend.py | 11 +++- text_recognizer/network/transformer/attention.py | 19 ++++--- text_recognizer/network/transformer/decoder.py | 8 ++- text_recognizer/network/transformer/encoder.py | 4 ++ text_recognizer/network/transformer/norm.py | 2 +- text_recognizer/network/transformer/swiglu.py | 2 +- text_recognizer/network/transformer/transformer.py | 48 ++++++++++++++++ text_recognizer/network/transformer/vit.py | 64 ++++++++++++++++++++++ 8 files changed, 143 insertions(+), 15 deletions(-) create mode 100644 text_recognizer/network/transformer/transformer.py create mode 100644 text_recognizer/network/transformer/vit.py (limited to 'text_recognizer/network/transformer') diff --git a/text_recognizer/network/transformer/attend.py b/text_recognizer/network/transformer/attend.py index a5c23c6..23a6487 100644 --- a/text_recognizer/network/transformer/attend.py +++ b/text_recognizer/network/transformer/attend.py @@ -1,10 +1,10 @@ -from typing import Optional from collections import namedtuple +from typing import Optional import torch -from torch import Tensor, einsum, nn -from einops import rearrange import torch.nn.functional as F +from einops import rearrange +from torch import Tensor, einsum, nn Config = namedtuple( "FlashAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"] @@ -79,6 +79,11 @@ class Attend(nn.Module): causal: bool, mask: Optional[Tensor] = None, ) -> Tensor: + if k.ndim == 3: + k = rearrange(k, 'b ... -> b 1 ...').expand_as(q) + if v.ndim == 3: + v = rearrange(v, 'b ... -> b 1 ...').expand_as(q) + if mask is not None: mask = rearrange(mask, "b j -> b 1 1 j") if self.use_flash: diff --git a/text_recognizer/network/transformer/attention.py b/text_recognizer/network/transformer/attention.py index bae077f..9a4aa0d 100644 --- a/text_recognizer/network/transformer/attention.py +++ b/text_recognizer/network/transformer/attention.py @@ -1,11 +1,12 @@ """Implements the attention module for the transformer.""" from typing import Optional -from einops import rearrange -from text_recognizer.network.transformer.swiglu import SwiGLU import torch +from einops import rearrange from torch import Tensor, nn +from text_recognizer.network.transformer.swiglu import SwiGLU + from .attend import Attend from .embedding.rotary import RotaryEmbedding, apply_rotary_pos_emb @@ -23,7 +24,8 @@ class Attention(nn.Module): dropout_rate: float = 0.0, use_flash: bool = True, norm_context: bool = False, - rotary_emb: Optional[RotaryEmbedding] = None, + use_rotary_emb: bool = False, + one_kv_head: bool = False, ) -> None: super().__init__() self.heads = heads @@ -36,12 +38,13 @@ class Attention(nn.Module): self.norm = nn.LayerNorm(dim) self.context_norm = nn.LayerNorm(dim) if norm_context else nn.Identity() self.to_q = nn.Linear(dim, inner_dim, bias=False) - self.to_kv = nn.Linear(dim, 2 * inner_dim, bias=False) + self.kv_heads = 1 if one_kv_head else heads + self.to_kv = nn.Linear(dim, 2 * self.kv_heads * dim_head, bias=False) self.attend = Attend(use_flash) self.to_out = nn.Linear(inner_dim, dim, bias=False) - self.rotary_emb = rotary_emb + self.rotary_emb = RotaryEmbedding(dim_head) if use_rotary_emb else None self.pos_emb = None ff_inner_dim = ff_mult * dim @@ -68,9 +71,9 @@ class Attention(nn.Module): k, v = self.to_kv(x if context is None else self.context_norm(context)).chunk( 2, dim=-1 ) - - q, k, v = map( - lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (q, k, v) + q = rearrange(q, "b n (h d) -> b h n d", h=self.heads) + k, v = map( + lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.kv_heads), (k, v) ) if self.rotary_emb is not None: diff --git a/text_recognizer/network/transformer/decoder.py b/text_recognizer/network/transformer/decoder.py index 4ebdd2c..60e426a 100644 --- a/text_recognizer/network/transformer/decoder.py +++ b/text_recognizer/network/transformer/decoder.py @@ -1,9 +1,9 @@ """Transformer decoder module.""" from typing import Optional + from torch import Tensor, nn from .attention import Attention -from .embedding.rotary import RotaryEmbedding class Decoder(nn.Module): @@ -15,6 +15,7 @@ class Decoder(nn.Module): dim_head: int, depth: int, dropout_rate: float = 0.0, + one_kv_head: bool = False, ) -> None: super().__init__() self.norm = nn.LayerNorm(dim) @@ -31,7 +32,8 @@ class Decoder(nn.Module): dropout_rate=dropout_rate, use_flash=True, norm_context=False, - rotary_emb=RotaryEmbedding(dim_head), + use_rotary_emb=True, + one_kv_head=one_kv_head, ), Attention( dim=dim, @@ -42,6 +44,8 @@ class Decoder(nn.Module): dropout_rate=dropout_rate, use_flash=True, norm_context=False, + use_rotary_emb=False, + one_kv_head=one_kv_head, ), ] ) diff --git a/text_recognizer/network/transformer/encoder.py b/text_recognizer/network/transformer/encoder.py index 1728c61..ce30372 100644 --- a/text_recognizer/network/transformer/encoder.py +++ b/text_recognizer/network/transformer/encoder.py @@ -13,6 +13,8 @@ class Encoder(nn.Module): ff_mult: int, depth: int, dropout_rate: float = 0.0, + use_rotary_emb: bool = False, + one_kv_head: bool = False, ) -> None: super().__init__() self.norm = nn.LayerNorm(dim) @@ -27,6 +29,8 @@ class Encoder(nn.Module): dropout_rate=dropout_rate, use_flash=True, norm_context=False, + use_rotary_emb=use_rotary_emb, + one_kv_head=one_kv_head, ) for _ in range(depth) ] diff --git a/text_recognizer/network/transformer/norm.py b/text_recognizer/network/transformer/norm.py index 2737754..9ba35a8 100644 --- a/text_recognizer/network/transformer/norm.py +++ b/text_recognizer/network/transformer/norm.py @@ -5,8 +5,8 @@ Copied from lucidrains: """ import torch -from torch import Tensor, nn import torch.nn.functional as F +from torch import Tensor, nn class RMSNorm(nn.Module): diff --git a/text_recognizer/network/transformer/swiglu.py b/text_recognizer/network/transformer/swiglu.py index e61662a..7bafd06 100644 --- a/text_recognizer/network/transformer/swiglu.py +++ b/text_recognizer/network/transformer/swiglu.py @@ -1,5 +1,5 @@ -from torch import nn import torch.nn.functional as F +from torch import nn class SwiGLU(nn.Module): diff --git a/text_recognizer/network/transformer/transformer.py b/text_recognizer/network/transformer/transformer.py new file mode 100644 index 0000000..298308e --- /dev/null +++ b/text_recognizer/network/transformer/transformer.py @@ -0,0 +1,48 @@ +from torch import Tensor, nn + +from .decoder import Decoder +from .embedding.token import TokenEmbedding +from .vit import Vit + + +class Transformer(nn.Module): + def __init__( + self, + dim: int, + num_classes: int, + encoder: Vit, + decoder: Decoder, + token_embedding: TokenEmbedding, + tie_embeddings: bool, + pad_index: int, + ) -> None: + super().__init__() + self.token_embedding = token_embedding + self.to_logits = ( + nn.Linear(dim, num_classes) + if not tie_embeddings + else lambda t: t @ self.token_embedding.to_embedding.weight.t() + ) + self.encoder = encoder + self.decoder = decoder + self.pad_index = pad_index + + def encode(self, images: Tensor) -> Tensor: + return self.encoder(images) + + def decode(self, text: Tensor, img_features: Tensor) -> Tensor: + text = text.long() + mask = text != self.pad_index + tokens = self.token_embedding(text) + output = self.decoder(tokens, context=img_features, mask=mask) + return self.to_logits(output) + + def forward( + self, + img: Tensor, + text: Tensor, + ) -> Tensor: + """Applies decoder block on input signals.""" + img_features = self.encode(img) + logits = self.decode(text, img_features) + return logits # [B, N, C] diff --git a/text_recognizer/network/transformer/vit.py b/text_recognizer/network/transformer/vit.py new file mode 100644 index 0000000..3b600c3 --- /dev/null +++ b/text_recognizer/network/transformer/vit.py @@ -0,0 +1,64 @@ +import torch +from einops import rearrange +from einops.layers.torch import Rearrange +from torch import Tensor, nn + +from .embedding.sincos import sincos_2d +from .encoder import Encoder + + +class PatchDropout(nn.Module): + def __init__(self, prob): + super().__init__() + assert 0 <= prob < 1. + self.prob = prob + + def forward(self, x): + if not self.training or self.prob == 0.: + return x + + b, n, _, device = *x.shape, x.device + + batch_indices = torch.arange(b, device = device) + batch_indices = rearrange(batch_indices, '... -> ... 1') + num_patches_keep = max(1, int(n * (1 - self.prob))) + patch_indices_keep = torch.randn(b, n, device = device).topk(num_patches_keep, dim = -1).indices + + return x[batch_indices, patch_indices_keep] + + +class Vit(nn.Module): + def __init__( + self, + image_height: int, + image_width: int, + patch_height: int, + patch_width: int, + dim: int, + encoder: Encoder, + channels: int = 1, + patch_dropout: float = 0.0, + ) -> None: + super().__init__() + patch_dim = patch_height * patch_width * channels + self.to_patch_embedding = nn.Sequential( + Rearrange( + "b c (h ph) (w pw) -> b (h w) (ph pw c)", + ph=patch_height, + pw=patch_width, + ), + nn.LayerNorm(patch_dim), + nn.Linear(patch_dim, dim), + nn.LayerNorm(dim), + ) + self.patch_embedding = sincos_2d( + h=image_height // patch_height, w=image_width // patch_width, dim=dim + ) + self.encoder = encoder + self.patch_dropout = PatchDropout(patch_dropout) + + def forward(self, images: Tensor) -> Tensor: + x = self.to_patch_embedding(images) + x = x + self.patch_embedding.to(images.device, dtype=images.dtype) + x = self.patch_dropout(x) + return self.encoder(x) -- cgit v1.2.3-70-g09d2