diff options
Diffstat (limited to 'text_recognizer')
| -rw-r--r-- | text_recognizer/networks/transformer/embeddings/fourier.py | 50 | 
1 files changed, 2 insertions, 48 deletions
| diff --git a/text_recognizer/networks/transformer/embeddings/fourier.py b/text_recognizer/networks/transformer/embeddings/fourier.py index c50afc3..ade589c 100644 --- a/text_recognizer/networks/transformer/embeddings/fourier.py +++ b/text_recognizer/networks/transformer/embeddings/fourier.py @@ -1,5 +1,4 @@ -"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence.""" -from einops import repeat +"""Fourier positional embedding."""  import numpy as np  import torch  from torch import nn @@ -35,51 +34,6 @@ class PositionalEncoding(nn.Module):          """Encodes the tensor with a postional embedding."""          # [T, B, D]          if x.shape[2] != self.pe.shape[2]: -            raise ValueError(f"x shape does not match pe in the 3rd dim.") +            raise ValueError("x shape does not match pe in the 3rd dim.")          x = x + self.pe[: x.shape[0]]          return self.dropout(x) - - -class PositionalEncoding2D(nn.Module): -    """Positional encodings for feature maps.""" - -    def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int = 2048) -> None: -        super().__init__() -        if hidden_dim % 2 != 0: -            raise ValueError(f"Embedding depth {hidden_dim} is not even!") -        self.hidden_dim = hidden_dim -        pe = self.make_pe(hidden_dim, max_h, max_w) -        self.register_buffer("pe", pe) - -    @staticmethod -    def make_pe(hidden_dim: int, max_h: int, max_w: int) -> Tensor: -        """Returns 2d postional encoding.""" -        pe_h = PositionalEncoding.make_pe( -            hidden_dim // 2, max_len=max_h -        )  # [H, 1, D // 2] -        pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w) - -        pe_w = PositionalEncoding.make_pe( -            hidden_dim // 2, max_len=max_w -        )  # [W, 1, D // 2] -        pe_w = repeat(pe_w, "w h d -> d (h tile) w", tile=max_h) - -        pe = torch.cat([pe_h, pe_w], dim=0)  # [D, H, W] -        return pe - -    def forward(self, x: Tensor) -> Tensor: -        """Adds 2D postional encoding to input tensor.""" -        # Assumes x hase shape [B, D, H, W] -        if x.shape[1] != self.pe.shape[0]: -            raise ValueError("Hidden dimensions does not match.") -        x += self.pe[:, : x.shape[2], : x.shape[3]] -        return x - - -def target_padding_mask(trg: Tensor, pad_index: int) -> Tensor: -    """Returns causal target mask.""" -    trg_pad_mask = (trg != pad_index)[:, None, None] -    trg_len = trg.shape[1] -    trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).bool() -    trg_mask = trg_pad_mask & trg_sub_mask -    return trg_mask |