diff options
Diffstat (limited to 'text_recognizer/networks/transformer')
| -rw-r--r-- | text_recognizer/networks/transformer/positional_encoding.py | 6 | ||||
| -rw-r--r-- | text_recognizer/networks/transformer/rotary_embedding.py | 39 | 
2 files changed, 44 insertions, 1 deletion
| diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py index 5874e97..c50afc3 100644 --- a/text_recognizer/networks/transformer/positional_encoding.py +++ b/text_recognizer/networks/transformer/positional_encoding.py @@ -33,7 +33,10 @@ class PositionalEncoding(nn.Module):      def forward(self, x: Tensor) -> Tensor:          """Encodes the tensor with a postional embedding.""" -        x = x + self.pe[:, : x.shape[1]] +        # [T, B, D] +        if x.shape[2] != self.pe.shape[2]: +            raise ValueError(f"x shape does not match pe in the 3rd dim.") +        x = x + self.pe[: x.shape[0]]          return self.dropout(x) @@ -48,6 +51,7 @@ class PositionalEncoding2D(nn.Module):          pe = self.make_pe(hidden_dim, max_h, max_w)          self.register_buffer("pe", pe) +    @staticmethod      def make_pe(hidden_dim: int, max_h: int, max_w: int) -> Tensor:          """Returns 2d postional encoding."""          pe_h = PositionalEncoding.make_pe( diff --git a/text_recognizer/networks/transformer/rotary_embedding.py b/text_recognizer/networks/transformer/rotary_embedding.py new file mode 100644 index 0000000..5e80572 --- /dev/null +++ b/text_recognizer/networks/transformer/rotary_embedding.py @@ -0,0 +1,39 @@ +"""Roatary embedding. 
+ +Stolen from lucidrains: +    https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py + +Explanation of roatary: +    https://blog.eleuther.ai/rotary-embeddings/ + +""" +from typing import Tuple + +from einops import rearrange +import torch +from torch import nn +from torch import Tensor + + +class RotaryEmbedding(nn.Module): +    def __init__(self, dim: int): +        super().__init__() +        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) +        self.register_buffer("inv_freq", inv_freq) + +    def forward(self, x: Tensor, seq_dim: int = 1) -> Tensor: +        t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) +        freqs = torch.einsum("i , j -> i j", t, self.inv_freq) +        emb = torch.cat((freqs, freqs), dim=-1) +        return emb[None, :, :] + + +def rotate_half(x: Tensor) -> Tensor: +    x = rearrange(x, "... (j d) -> ... j d", j=2) +    x1, x2 = x.unbind(dim=-2) +    return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q: Tensor, k: Tensor, freqs: Tensor) -> Tuple[Tensor, Tensor]: +    q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k)) +    return q, k |