Diffstat (limited to 'text_recognizer')
-rw-r--r--  text_recognizer/network/transformer/attend.py                                              | 11
-rw-r--r--  text_recognizer/network/transformer/attention.py                                           | 19
-rw-r--r--  text_recognizer/network/transformer/decoder.py                                             |  8
-rw-r--r--  text_recognizer/network/transformer/encoder.py                                             |  4
-rw-r--r--  text_recognizer/network/transformer/norm.py                                                |  2
-rw-r--r--  text_recognizer/network/transformer/swiglu.py                                              |  2
-rw-r--r--  text_recognizer/network/transformer/transformer.py                                         | 48
-rw-r--r--  text_recognizer/network/transformer/vit.py (renamed from text_recognizer/network/vit.py)   | 29
8 files changed, 106 insertions, 17 deletions
diff --git a/text_recognizer/network/transformer/attend.py b/text_recognizer/network/transformer/attend.py
index a5c23c6..23a6487 100644
--- a/text_recognizer/network/transformer/attend.py
+++ b/text_recognizer/network/transformer/attend.py
@@ -1,10 +1,10 @@
-from typing import Optional
 from collections import namedtuple
+from typing import Optional
 
 import torch
-from torch import Tensor, einsum, nn
-from einops import rearrange
 import torch.nn.functional as F
+from einops import rearrange
+from torch import Tensor, einsum, nn
 
 Config = namedtuple(
     "FlashAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"]
@@ -79,6 +79,11 @@ class Attend(nn.Module):
         causal: bool,
         mask: Optional[Tensor] = None,
     ) -> Tensor:
+        if k.ndim == 3:
+            k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
+        if v.ndim == 3:
+            v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)
+
         if mask is not None:
             mask = rearrange(mask, "b j -> b 1 1 j")
         if self.use_flash:
diff --git a/text_recognizer/network/transformer/attention.py b/text_recognizer/network/transformer/attention.py
index bae077f..9a4aa0d 100644
--- a/text_recognizer/network/transformer/attention.py
+++ b/text_recognizer/network/transformer/attention.py
@@ -1,11 +1,12 @@
 """Implements the attention module for the transformer."""
 from typing import Optional
 
-from einops import rearrange
-from text_recognizer.network.transformer.swiglu import SwiGLU
 import torch
+from einops import rearrange
 from torch import Tensor, nn
 
+from text_recognizer.network.transformer.swiglu import SwiGLU
+
 from .attend import Attend
 from .embedding.rotary import RotaryEmbedding, apply_rotary_pos_emb
 
@@ -23,7 +24,8 @@ class Attention(nn.Module):
         dropout_rate: float = 0.0,
         use_flash: bool = True,
         norm_context: bool = False,
-        rotary_emb: Optional[RotaryEmbedding] = None,
+        use_rotary_emb: bool = False,
+        one_kv_head: bool = False,
     ) -> None:
         super().__init__()
         self.heads = heads
@@ -36,12 +38,13 @@
         self.norm = nn.LayerNorm(dim)
         self.context_norm = nn.LayerNorm(dim) if norm_context else nn.Identity()
         self.to_q = nn.Linear(dim, inner_dim, bias=False)
-        self.to_kv = nn.Linear(dim, 2 * inner_dim, bias=False)
+        self.kv_heads = 1 if one_kv_head else heads
+        self.to_kv = nn.Linear(dim, 2 * self.kv_heads * dim_head, bias=False)
 
         self.attend = Attend(use_flash)
         self.to_out = nn.Linear(inner_dim, dim, bias=False)
 
-        self.rotary_emb = rotary_emb
+        self.rotary_emb = RotaryEmbedding(dim_head) if use_rotary_emb else None
         self.pos_emb = None
 
         ff_inner_dim = ff_mult * dim
@@ -68,9 +71,9 @@
         k, v = self.to_kv(x if context is None else self.context_norm(context)).chunk(
             2, dim=-1
         )
-
-        q, k, v = map(
-            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (q, k, v)
+        q = rearrange(q, "b n (h d) -> b h n d", h=self.heads)
+        k, v = map(
+            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.kv_heads), (k, v)
         )
 
         if self.rotary_emb is not None:
diff --git a/text_recognizer/network/transformer/decoder.py b/text_recognizer/network/transformer/decoder.py
index 4ebdd2c..60e426a 100644
--- a/text_recognizer/network/transformer/decoder.py
+++ b/text_recognizer/network/transformer/decoder.py
@@ -1,9 +1,9 @@
 """Transformer decoder module."""
 from typing import Optional
+
 from torch import Tensor, nn
 
 from .attention import Attention
-from .embedding.rotary import RotaryEmbedding
 
 
 class Decoder(nn.Module):
@@ -15,6 +15,7 @@
         dim_head: int,
         depth: int,
         dropout_rate: float = 0.0,
+        one_kv_head: bool = False,
     ) -> None:
         super().__init__()
         self.norm = nn.LayerNorm(dim)
@@ -31,7 +32,8 @@
                             dropout_rate=dropout_rate,
                             use_flash=True,
                             norm_context=False,
-                            rotary_emb=RotaryEmbedding(dim_head),
+                            use_rotary_emb=True,
+                            one_kv_head=one_kv_head,
                         ),
                         Attention(
                             dim=dim,
@@ -42,6 +44,8 @@
                             dropout_rate=dropout_rate,
                             use_flash=True,
                             norm_context=False,
+                            use_rotary_emb=False,
+                            one_kv_head=one_kv_head,
                         ),
                     ]
                 )
diff --git a/text_recognizer/network/transformer/encoder.py b/text_recognizer/network/transformer/encoder.py
index 1728c61..ce30372 100644
--- a/text_recognizer/network/transformer/encoder.py
+++ b/text_recognizer/network/transformer/encoder.py
@@ -13,6 +13,8 @@ class Encoder(nn.Module):
         ff_mult: int,
         depth: int,
         dropout_rate: float = 0.0,
+        use_rotary_emb: bool = False,
+        one_kv_head: bool = False,
     ) -> None:
         super().__init__()
         self.norm = nn.LayerNorm(dim)
@@ -27,6 +29,8 @@
                     dropout_rate=dropout_rate,
                     use_flash=True,
                     norm_context=False,
+                    use_rotary_emb=use_rotary_emb,
+                    one_kv_head=one_kv_head,
                 )
                 for _ in range(depth)
             ]
diff --git a/text_recognizer/network/transformer/norm.py b/text_recognizer/network/transformer/norm.py
index 2737754..9ba35a8 100644
--- a/text_recognizer/network/transformer/norm.py
+++ b/text_recognizer/network/transformer/norm.py
@@ -5,8 +5,8 @@ Copied from lucidrains:
 """
 
 import torch
-from torch import Tensor, nn
 import torch.nn.functional as F
+from torch import Tensor, nn
 
 
 class RMSNorm(nn.Module):
diff --git a/text_recognizer/network/transformer/swiglu.py b/text_recognizer/network/transformer/swiglu.py
index e61662a..7bafd06 100644
--- a/text_recognizer/network/transformer/swiglu.py
+++ b/text_recognizer/network/transformer/swiglu.py
@@ -1,5 +1,5 @@
-from torch import nn
 import torch.nn.functional as F
+from torch import nn
 
 
 class SwiGLU(nn.Module):
diff --git a/text_recognizer/network/transformer/transformer.py b/text_recognizer/network/transformer/transformer.py
new file mode 100644
index 0000000..298308e
--- /dev/null
+++ b/text_recognizer/network/transformer/transformer.py
@@ -0,0 +1,48 @@
+from torch import Tensor, nn
+
+from .decoder import Decoder
+from .embedding.token import TokenEmbedding
+from .vit import Vit
+
+
+class Transformer(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_classes: int,
+        encoder: Vit,
+        decoder: Decoder,
+        token_embedding: TokenEmbedding,
+        tie_embeddings: bool,
+        pad_index: int,
+    ) -> None:
+        super().__init__()
+        self.token_embedding = token_embedding
+        self.to_logits = (
+            nn.Linear(dim, num_classes)
+            if not tie_embeddings
+            else lambda t: t @ self.token_embedding.to_embedding.weight.t()
+        )
+        self.encoder = encoder
+        self.decoder = decoder
+        self.pad_index = pad_index
+
+    def encode(self, images: Tensor) -> Tensor:
+        return self.encoder(images)
+
+    def decode(self, text: Tensor, img_features: Tensor) -> Tensor:
+        text = text.long()
+        mask = text != self.pad_index
+        tokens = self.token_embedding(text)
+        output = self.decoder(tokens, context=img_features, mask=mask)
+        return self.to_logits(output)
+
+    def forward(
+        self,
+        img: Tensor,
+        text: Tensor,
+    ) -> Tensor:
+        """Applies decoder block on input signals."""
+        img_features = self.encode(img)
+        logits = self.decode(text, img_features)
+        return logits  # [B, N, C]
diff --git a/text_recognizer/network/vit.py b/text_recognizer/network/transformer/vit.py
index a596792..3b600c3 100644
--- a/text_recognizer/network/vit.py
+++ b/text_recognizer/network/transformer/vit.py
@@ -1,8 +1,30 @@
+import torch
+from einops import rearrange
 from einops.layers.torch import Rearrange
 from torch import Tensor, nn
 
-from .transformer.embedding.sincos import sincos_2d
-from .transformer.encoder import Encoder
+from .embedding.sincos import sincos_2d
+from .encoder import Encoder
+
+
+class PatchDropout(nn.Module):
+    def __init__(self, prob):
+        super().__init__()
+        assert 0 <= prob < 1.
+        self.prob = prob
+
+    def forward(self, x):
+        if not self.training or self.prob == 0.:
+            return x
+
+        b, n, _, device = *x.shape, x.device
+
+        batch_indices = torch.arange(b, device = device)
+        batch_indices = rearrange(batch_indices, '... -> ... 1')
+        num_patches_keep = max(1, int(n * (1 - self.prob)))
+        patch_indices_keep = torch.randn(b, n, device = device).topk(num_patches_keep, dim = -1).indices
+
+        return x[batch_indices, patch_indices_keep]
 
 
 class Vit(nn.Module):
@@ -15,6 +37,7 @@
         dim: int,
         encoder: Encoder,
         channels: int = 1,
+        patch_dropout: float = 0.0,
     ) -> None:
         super().__init__()
         patch_dim = patch_height * patch_width * channels
@@ -32,8 +55,10 @@
             h=image_height // patch_height, w=image_width // patch_width, dim=dim
         )
         self.encoder = encoder
+        self.patch_dropout = PatchDropout(patch_dropout)
 
     def forward(self, images: Tensor) -> Tensor:
         x = self.to_patch_embedding(images)
         x = x + self.patch_embedding.to(images.device, dtype=images.dtype)
+        x = self.patch_dropout(x)
         return self.encoder(x)
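
Note on the one_kv_head change: with one_kv_head=True, Attention.to_kv projects to a single key/value head and Attend broadcasts it across all query heads before attention (a multi-query style variant). A minimal sketch of that broadcast, using torch's built-in scaled_dot_product_attention instead of the repository's Attend class; all shapes below are illustrative only.

import torch
import torch.nn.functional as F
from einops import rearrange

b, heads, n, dim_head = 2, 8, 16, 64
q = torch.randn(b, heads, n, dim_head)  # per-head queries
k = torch.randn(b, n, dim_head)         # a single shared key head (one_kv_head=True)
v = torch.randn(b, n, dim_head)         # a single shared value head

# Mirrors the new branch in Attend.forward: add a head axis, then expand to match q.
k = rearrange(k, "b ... -> b 1 ...").expand_as(q)
v = rearrange(v, "b ... -> b 1 ...").expand_as(q)

out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([2, 8, 16, 64])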
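Note on tie_embeddings in the new Transformer module: when enabled, the output projection is a matrix product against the transposed token-embedding weight instead of a separate nn.Linear. A rough sketch of that path, assuming TokenEmbedding.to_embedding wraps a plain nn.Embedding; the sizes are made up for the example.

import torch
from torch import nn

num_classes, dim = 60, 128                      # illustrative vocabulary and model width
embedding = nn.Embedding(num_classes, dim)      # stands in for TokenEmbedding.to_embedding

decoder_output = torch.randn(2, 32, dim)        # [B, N, dim] from the decoder
logits = decoder_output @ embedding.weight.t()  # [B, N, num_classes], no extra parameters
print(logits.shape)  # torch.Size([2, 32, 60])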
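Note on PatchDropout in the renamed vit.py: during training it keeps a random subset of patch tokens per sample; in eval mode it is a no-op. A usage sketch, importing the class from its new location; the dropout rate and tensor sizes are arbitrary.

import torch

from text_recognizer.network.transformer.vit import PatchDropout

dropout = PatchDropout(prob=0.25)
tokens = torch.randn(4, 100, 256)  # [batch, num_patches, dim]

dropout.train()
print(dropout(tokens).shape)  # torch.Size([4, 75, 256]): 75 of 100 patches kept

dropout.eval()
print(dropout(tokens).shape)  # torch.Size([4, 100, 256]): unchanged outside training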