diff options
Diffstat (limited to 'text_recognizer')
5 files changed, 4 insertions, 79 deletions
diff --git a/text_recognizer/networks/text_decoder.py b/text_recognizer/networks/text_decoder.py index c054b41..7ee6720 100644 --- a/text_recognizer/networks/text_decoder.py +++ b/text_recognizer/networks/text_decoder.py @@ -1,5 +1,5 @@ """Text decoder.""" -from typing import Type +from typing import Optional, Type import torch from torch import Tensor, nn @@ -16,7 +16,6 @@ class TextDecoder(nn.Module): num_classes: int, pad_index: Tensor, decoder: Decoder, - token_pos_embedding: Type[nn.Module], ) -> None: super().__init__() self.hidden_dim = hidden_dim @@ -26,7 +25,6 @@ class TextDecoder(nn.Module): self.token_embedding = nn.Embedding( num_embeddings=self.num_classes, embedding_dim=self.hidden_dim ) - self.token_pos_embedding = token_pos_embedding self.to_logits = nn.Linear( in_features=self.hidden_dim, out_features=self.num_classes ) @@ -52,7 +50,6 @@ class TextDecoder(nn.Module): tokens = tokens.long() mask = tokens != self.pad_index tokens = self.token_embedding(tokens) - tokens = tokens + self.token_pos_embedding(tokens) tokens = self.decoder(x=tokens, context=img_features, mask=mask) logits = ( tokens @ torch.transpose(self.token_embedding.weight.to(tokens.dtype), 0, 1) diff --git a/text_recognizer/networks/transformer/decoder.py b/text_recognizer/networks/transformer/decoder.py index 741f5b3..09d2dce 100644 --- a/text_recognizer/networks/transformer/decoder.py +++ b/text_recognizer/networks/transformer/decoder.py @@ -1,13 +1,11 @@ """Transformer decoder module.""" from copy import deepcopy -from typing import Optional, Type +from typing import Optional from torch import Tensor, nn -from text_recognizer.networks.transformer.attention import Attention from text_recognizer.networks.transformer.decoder_block import DecoderBlock from text_recognizer.networks.transformer.embeddings.rotary import RotaryEmbedding -from text_recognizer.networks.transformer.ff import FeedForward class Decoder(nn.Module): @@ -18,7 +16,7 @@ class Decoder(nn.Module): depth: int, dim: int, block: DecoderBlock, - rotary_embedding: Optional[RotaryEmbedding] = None, + rotary_embedding: RotaryEmbedding, ) -> None: super().__init__() self.depth = depth diff --git a/text_recognizer/networks/transformer/decoder_block.py b/text_recognizer/networks/transformer/decoder_block.py index 2dc4ddf..f7ae454 100644 --- a/text_recognizer/networks/transformer/decoder_block.py +++ b/text_recognizer/networks/transformer/decoder_block.py @@ -30,9 +30,9 @@ class DecoderBlock(nn.Module): def forward( self, x: Tensor, + rotary_embedding: RotaryEmbedding, context: Optional[Tensor] = None, mask: Optional[Tensor] = None, - rotary_embedding: Optional[RotaryEmbedding] = None, ) -> Tensor: """Applies decoder block on input signals.""" x = x + self.attn(self.ln_attn(x), mask=mask, rotary_embedding=rotary_embedding) diff --git a/text_recognizer/networks/transformer/embeddings/absolute.py b/text_recognizer/networks/transformer/embeddings/absolute.py deleted file mode 100644 index 9274b55..0000000 --- a/text_recognizer/networks/transformer/embeddings/absolute.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Absolute positional embedding.""" - -import torch -import torch.nn.functional as F -from einops import rearrange -from torch import nn - - -def l2norm(t, groups=1): - t = rearrange(t, "... (g d) -> ... g d", g=groups) - t = F.normalize(t, p=2, dim=-1) - return rearrange(t, "... g d -> ... (g d)") - - -class AbsolutePositionalEmbedding(nn.Module): - def __init__(self, dim, max_seq_len, l2norm_embed=False): - super().__init__() - self.scale = dim**-0.5 if not l2norm_embed else 1.0 - self.max_seq_len = max_seq_len - self.l2norm_embed = l2norm_embed - self.emb = nn.Embedding(max_seq_len, dim) - - def forward(self, x, pos=None): - seq_len = x.shape[1] - assert ( - seq_len <= self.max_seq_len - ), f"you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}" - - if pos is None: - pos = torch.arange(seq_len, device=x.device) - - pos_emb = self.emb(pos) - pos_emb = pos_emb * self.scale - return l2norm(pos_emb) if self.l2norm_embed else pos_emb diff --git a/text_recognizer/networks/transformer/embeddings/fourier.py b/text_recognizer/networks/transformer/embeddings/fourier.py deleted file mode 100644 index 28da7a1..0000000 --- a/text_recognizer/networks/transformer/embeddings/fourier.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Fourier positional embedding.""" -import numpy as np -import torch -from torch import Tensor, nn - - -class PositionalEncoding(nn.Module): - """Encodes a sense of distance or time for transformer networks.""" - - def __init__(self, dim: int, dropout_rate: float, max_len: int = 1000) -> None: - super().__init__() - self.dropout = nn.Dropout(p=dropout_rate) - pe = self.make_pe(dim, max_len) - self.register_buffer("pe", pe) - - @staticmethod - def make_pe(hidden_dim: int, max_len: int) -> Tensor: - """Returns positional encoding.""" - pe = torch.zeros(max_len, hidden_dim) - position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) - div_term = torch.exp( - torch.arange(0, hidden_dim, 2).float() * (-np.log(10000.0) / hidden_dim) - ) - - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(1) - return pe - - def forward(self, x: Tensor) -> Tensor: - """Encodes the tensor with a postional embedding.""" - # [T, B, D] - if x.shape[2] != self.pe.shape[2]: - raise ValueError("x shape does not match pe in the 3rd dim.") - x = x + self.pe[: x.shape[0]] - return self.dropout(x) |