 text_recognizer/networks/text_decoder.py                    |  5 +----
 text_recognizer/networks/transformer/decoder.py             |  6 ++----
 text_recognizer/networks/transformer/decoder_block.py       |  2 +-
 text_recognizer/networks/transformer/embeddings/absolute.py | 34 -----------
 text_recognizer/networks/transformer/embeddings/fourier.py  | 36 ------------
 training/conf/experiment/conv_transformer_lines.yaml        | 16 +++--------
 6 files changed, 9 insertions(+), 90 deletions(-)
diff --git a/text_recognizer/networks/text_decoder.py b/text_recognizer/networks/text_decoder.py
index c054b41..7ee6720 100644
--- a/text_recognizer/networks/text_decoder.py
+++ b/text_recognizer/networks/text_decoder.py
@@ -1,5 +1,5 @@
"""Text decoder."""
-from typing import Type
+from typing import Optional, Type
import torch
from torch import Tensor, nn
@@ -16,7 +16,6 @@ class TextDecoder(nn.Module):
num_classes: int,
pad_index: Tensor,
decoder: Decoder,
- token_pos_embedding: Type[nn.Module],
) -> None:
super().__init__()
self.hidden_dim = hidden_dim
@@ -26,7 +25,6 @@ class TextDecoder(nn.Module):
self.token_embedding = nn.Embedding(
num_embeddings=self.num_classes, embedding_dim=self.hidden_dim
)
- self.token_pos_embedding = token_pos_embedding
self.to_logits = nn.Linear(
in_features=self.hidden_dim, out_features=self.num_classes
)
@@ -52,7 +50,6 @@ class TextDecoder(nn.Module):
tokens = tokens.long()
mask = tokens != self.pad_index
tokens = self.token_embedding(tokens)
- tokens = tokens + self.token_pos_embedding(tokens)
tokens = self.decoder(x=tokens, context=img_features, mask=mask)
logits = (
tokens @ torch.transpose(self.token_embedding.weight.to(tokens.dtype), 0, 1)
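
With the additive token positional embedding removed, all positional information for the decoder now comes from the rotary embedding applied inside the attention layers, while logits are still produced by reusing the token embedding matrix (weight tying), as in the unchanged lines above. A minimal sketch of the resulting forward path, with hypothetical shapes and values:

    import torch
    from torch import nn

    num_classes, hidden_dim = 58, 384          # illustrative values only
    token_embedding = nn.Embedding(num_classes, hidden_dim)

    tokens = torch.randint(0, num_classes, (1, 16))
    x = token_embedding(tokens)                # [B, T, D]; no additive positions
    # ... decoder blocks inject position via rotary embeddings ...
    logits = x @ token_embedding.weight.t()    # weight-tied projection -> [B, T, C]
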
diff --git a/text_recognizer/networks/transformer/decoder.py b/text_recognizer/networks/transformer/decoder.py
index 741f5b3..09d2dce 100644
--- a/text_recognizer/networks/transformer/decoder.py
+++ b/text_recognizer/networks/transformer/decoder.py
@@ -1,13 +1,11 @@
"""Transformer decoder module."""
from copy import deepcopy
-from typing import Optional, Type
+from typing import Optional
from torch import Tensor, nn
-from text_recognizer.networks.transformer.attention import Attention
from text_recognizer.networks.transformer.decoder_block import DecoderBlock
from text_recognizer.networks.transformer.embeddings.rotary import RotaryEmbedding
-from text_recognizer.networks.transformer.ff import FeedForward
class Decoder(nn.Module):
@@ -18,7 +16,7 @@ class Decoder(nn.Module):
depth: int,
dim: int,
block: DecoderBlock,
- rotary_embedding: Optional[RotaryEmbedding] = None,
+ rotary_embedding: RotaryEmbedding,
) -> None:
super().__init__()
self.depth = depth
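
Making rotary_embedding a required argument means every Decoder now owns one RotaryEmbedding and threads it through its blocks, matching the new DecoderBlock signature below. A simplified sketch of that pattern (not the repository's exact code):

    from copy import deepcopy
    from torch import nn

    class Decoder(nn.Module):
        def __init__(self, depth, dim, block, rotary_embedding):
            super().__init__()
            self.rotary_embedding = rotary_embedding  # mandatory, shared by all blocks
            self.blocks = nn.ModuleList([deepcopy(block) for _ in range(depth)])

        def forward(self, x, context=None, mask=None):
            for block in self.blocks:
                x = block(x, self.rotary_embedding, context=context, mask=mask)
            return x
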
diff --git a/text_recognizer/networks/transformer/decoder_block.py b/text_recognizer/networks/transformer/decoder_block.py
index 2dc4ddf..f7ae454 100644
--- a/text_recognizer/networks/transformer/decoder_block.py
+++ b/text_recognizer/networks/transformer/decoder_block.py
@@ -30,9 +30,9 @@ class DecoderBlock(nn.Module):
def forward(
self,
x: Tensor,
+ rotary_embedding: RotaryEmbedding,
context: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
- rotary_embedding: Optional[RotaryEmbedding] = None,
) -> Tensor:
"""Applies decoder block on input signals."""
x = x + self.attn(self.ln_attn(x), mask=mask, rotary_embedding=rotary_embedding)
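
For context: a rotary embedding encodes position by rotating pairs of query/key dimensions by position-dependent angles inside attention, so relative offsets appear directly in the dot products and no separate additive embedding is needed. A generic sketch of the usual RoPE formulation (the repository's RotaryEmbedding may differ in detail):

    import torch

    def rotate_half(x):
        # Split the last dimension in two and rotate: (x1, x2) -> (-x2, x1).
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary(t, freqs):
        # t: queries or keys [B, H, T, D]; freqs: per-position angles [T, D].
        return t * freqs.cos() + rotate_half(t) * freqs.sin()
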
diff --git a/text_recognizer/networks/transformer/embeddings/absolute.py b/text_recognizer/networks/transformer/embeddings/absolute.py
deleted file mode 100644
index 9274b55..0000000
--- a/text_recognizer/networks/transformer/embeddings/absolute.py
+++ /dev/null
@@ -1,34 +0,0 @@
-"""Absolute positional embedding."""
-
-import torch
-import torch.nn.functional as F
-from einops import rearrange
-from torch import nn
-
-
-def l2norm(t, groups=1):
- t = rearrange(t, "... (g d) -> ... g d", g=groups)
- t = F.normalize(t, p=2, dim=-1)
- return rearrange(t, "... g d -> ... (g d)")
-
-
-class AbsolutePositionalEmbedding(nn.Module):
- def __init__(self, dim, max_seq_len, l2norm_embed=False):
- super().__init__()
- self.scale = dim**-0.5 if not l2norm_embed else 1.0
- self.max_seq_len = max_seq_len
- self.l2norm_embed = l2norm_embed
- self.emb = nn.Embedding(max_seq_len, dim)
-
- def forward(self, x, pos=None):
- seq_len = x.shape[1]
- assert (
- seq_len <= self.max_seq_len
- ), f"you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}"
-
- if pos is None:
- pos = torch.arange(seq_len, device=x.device)
-
- pos_emb = self.emb(pos)
- pos_emb = pos_emb * self.scale
- return l2norm(pos_emb) if self.l2norm_embed else pos_emb
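
For reference, the removed module added one learned vector per absolute position on top of the token embeddings. A hypothetical usage of the class as defined above, with illustrative dimensions:

    import torch

    emb = AbsolutePositionalEmbedding(dim=384, max_seq_len=682)
    x = torch.randn(1, 16, 384)                # [B, T, D]
    x = x + emb(x)                             # additive, learned positions
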
diff --git a/text_recognizer/networks/transformer/embeddings/fourier.py b/text_recognizer/networks/transformer/embeddings/fourier.py
deleted file mode 100644
index 28da7a1..0000000
--- a/text_recognizer/networks/transformer/embeddings/fourier.py
+++ /dev/null
@@ -1,36 +0,0 @@
-"""Fourier positional embedding."""
-import numpy as np
-import torch
-from torch import Tensor, nn
-
-
-class PositionalEncoding(nn.Module):
- """Encodes a sense of distance or time for transformer networks."""
-
- def __init__(self, dim: int, dropout_rate: float, max_len: int = 1000) -> None:
- super().__init__()
- self.dropout = nn.Dropout(p=dropout_rate)
- pe = self.make_pe(dim, max_len)
- self.register_buffer("pe", pe)
-
- @staticmethod
- def make_pe(hidden_dim: int, max_len: int) -> Tensor:
- """Returns positional encoding."""
- pe = torch.zeros(max_len, hidden_dim)
- position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
- div_term = torch.exp(
- torch.arange(0, hidden_dim, 2).float() * (-np.log(10000.0) / hidden_dim)
- )
-
- pe[:, 0::2] = torch.sin(position * div_term)
- pe[:, 1::2] = torch.cos(position * div_term)
- pe = pe.unsqueeze(1)
- return pe
-
- def forward(self, x: Tensor) -> Tensor:
- """Encodes the tensor with a postional embedding."""
- # [T, B, D]
- if x.shape[2] != self.pe.shape[2]:
- raise ValueError("x shape does not match pe in the 3rd dim.")
- x = x + self.pe[: x.shape[0]]
- return self.dropout(x)
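
The deleted make_pe implements the standard sinusoidal encoding from "Attention Is All You Need":

    PE(pos, 2i)   = sin(pos / 10000^(2i/d))
    PE(pos, 2i+1) = cos(pos / 10000^(2i/d))

Note that the code's exp(2i * (-log(10000) / d)) equals 1 / 10000^(2i/d), so the two forms match.
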
diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml
index 4e921f2..3392cd6 100644
--- a/training/conf/experiment/conv_transformer_lines.yaml
+++ b/training/conf/experiment/conv_transformer_lines.yaml
@@ -83,7 +83,7 @@ network:
pixel_embedding:
_target_: "text_recognizer.networks.transformer.embeddings.axial.\
AxialPositionalEmbeddingImage"
- dim: &hidden_dim 384
+ dim: *dim
axial_shape: [7, 128]
axial_dims: [192, 192]
decoder:
@@ -96,19 +96,19 @@ network:
dim: *dim
depth: 6
block:
- _target_: text_recognizer.networks.transformer.decoder_block.\
- DecoderBlock
+ _target_: "text_recognizer.networks.transformer.decoder_block.\
+ DecoderBlock"
self_attn:
_target_: text_recognizer.networks.transformer.Attention
dim: *dim
- num_heads: 10
+ num_heads: 8
dim_head: 64
dropout_rate: &dropout_rate 0.2
causal: true
cross_attn:
_target_: text_recognizer.networks.transformer.Attention
dim: *dim
- num_heads: 10
+ num_heads: 8
dim_head: 64
dropout_rate: *dropout_rate
causal: false
@@ -125,12 +125,6 @@ network:
rotary_embedding:
_target_: text_recognizer.networks.transformer.RotaryEmbedding
dim: 64
- token_pos_embedding:
- _target_: "text_recognizer.networks.transformer.embeddings.fourier.\
- PositionalEncoding"
- dim: *dim
- dropout_rate: 0.1
- max_len: *max_output_len
model:
max_output_len: *max_output_len
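
The _target_ keys mark this as a Hydra config; assuming the training entry point follows the usual Hydra pattern, the network (including the now-required RotaryEmbedding) is built recursively from it, roughly:

    import hydra
    from omegaconf import OmegaConf

    # Hypothetical usage; the repository's actual entry point may differ,
    # and experiment configs are normally composed with a base config.
    cfg = OmegaConf.load("training/conf/experiment/conv_transformer_lines.yaml")
    network = hydra.utils.instantiate(cfg.network)  # recursively builds the decoder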