summaryrefslogtreecommitdiff
path: root/text_recognizer
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer')
-rw-r--r--text_recognizer/data/iam_extended_paragraphs.py12
-rw-r--r--text_recognizer/data/iam_paragraphs.py13
-rw-r--r--text_recognizer/data/iam_synthetic_paragraphs.py2
-rw-r--r--text_recognizer/data/mappings.py12
-rw-r--r--text_recognizer/data/transforms.py4
-rw-r--r--text_recognizer/models/transformer.py2
-rw-r--r--text_recognizer/networks/__init__.py2
-rw-r--r--text_recognizer/networks/backbones/__init__.py2
-rw-r--r--text_recognizer/networks/backbones/efficientnet.py145
-rw-r--r--text_recognizer/networks/cnn_transformer.py37
-rw-r--r--text_recognizer/networks/coat/__init__.py0
-rw-r--r--text_recognizer/networks/coat/factor_attention.py9
-rw-r--r--text_recognizer/networks/coat/patch_embedding.py38
-rw-r--r--text_recognizer/networks/coat/positional_encodings.py76
-rw-r--r--text_recognizer/networks/residual_network.py6
-rw-r--r--text_recognizer/networks/transducer/transducer.py7
-rw-r--r--text_recognizer/networks/transformer/positional_encoding.py6
-rw-r--r--text_recognizer/networks/transformer/rotary_embedding.py39
-rw-r--r--text_recognizer/networks/vqvae/decoder.py18
-rw-r--r--text_recognizer/networks/vqvae/encoder.py12
20 files changed, 247 insertions, 195 deletions
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 2380660..0a30a42 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -19,18 +19,10 @@ class IAMExtendedParagraphs(BaseDataModule):
super().__init__(batch_size, num_workers)
self.iam_paragraphs = IAMParagraphs(
- batch_size,
- num_workers,
- train_fraction,
- augment,
- word_pieces,
+ batch_size, num_workers, train_fraction, augment, word_pieces,
)
self.iam_synthetic_paragraphs = IAMSyntheticParagraphs(
- batch_size,
- num_workers,
- train_fraction,
- augment,
- word_pieces,
+ batch_size, num_workers, train_fraction, augment, word_pieces,
)
self.dims = self.iam_paragraphs.dims
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 62c44f9..24409bc 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -101,7 +101,7 @@ class IAMParagraphs(BaseDataModule):
data,
targets,
transform=get_transform(image_shape=self.dims[1:], augment=augment),
- target_transform=get_target_transform(self.word_pieces)
+ target_transform=get_target_transform(self.word_pieces),
)
logger.info(f"Loading IAM paragraph regions and lines for {stage}...")
@@ -162,10 +162,7 @@ def get_dataset_properties() -> Dict:
"min": min(_get_property_values("num_lines")),
"max": max(_get_property_values("num_lines")),
},
- "crop_shape": {
- "min": crop_shapes.min(axis=0),
- "max": crop_shapes.max(axis=0),
- },
+ "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0),},
"aspect_ratio": {
"min": aspect_ratio.min(axis=0),
"max": aspect_ratio.max(axis=0),
@@ -286,9 +283,7 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Com
),
transforms.ColorJitter(brightness=(0.8, 1.6)),
transforms.RandomAffine(
- degrees=1,
- shear=(-10, 10),
- interpolation=InterpolationMode.BILINEAR,
+ degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR,
),
]
else:
@@ -296,10 +291,12 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Com
transforms_list.append(transforms.ToTensor())
return transforms.Compose(transforms_list)
+
def get_target_transform(word_pieces: bool) -> Optional[transforms.Compose]:
"""Transform emnist characters to word pieces."""
return transforms.Compose([WordPiece()]) if word_pieces else None
+
def _labels_filename(split: str) -> Path:
"""Return filename of processed labels."""
return PROCESSED_DATA_DIRNAME / split / "_labels.json"
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index 4ccc5c2..78e6c05 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -97,7 +97,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
transform=get_transform(
image_shape=self.dims[1:], augment=self.augment
),
- target_transform=get_target_transform(self.word_pieces)
+ target_transform=get_target_transform(self.word_pieces),
)
def __repr__(self) -> str:
diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py
index f4016ba..190febe 100644
--- a/text_recognizer/data/mappings.py
+++ b/text_recognizer/data/mappings.py
@@ -58,13 +58,13 @@ class WordPieceMapping(EmnistMapping):
def __init__(
self,
num_features: int = 1000,
- tokens: str = "iamdb_1kwp_tokens_1000.txt" ,
+ tokens: str = "iamdb_1kwp_tokens_1000.txt",
lexicon: str = "iamdb_1kwp_lex_1000.txt",
data_dir: Optional[Union[str, Path]] = None,
use_words: bool = False,
prepend_wordsep: bool = False,
special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
- extra_symbols: Optional[Sequence[str]] = ("\n", ),
+ extra_symbols: Optional[Sequence[str]] = ("\n",),
) -> None:
super().__init__(extra_symbols)
self.wordpiece_processor = self._configure_wordpiece_processor(
@@ -90,7 +90,13 @@ class WordPieceMapping(EmnistMapping):
extra_symbols: Optional[Sequence[str]],
) -> Preprocessor:
data_dir = (
- (Path(__file__).resolve().parents[2] / "data" / "downloaded" / "iam" / "iamdb")
+ (
+ Path(__file__).resolve().parents[2]
+ / "data"
+ / "downloaded"
+ / "iam"
+ / "iamdb"
+ )
if data_dir is None
else Path(data_dir)
)
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
index 8d1bedd..d0f1f35 100644
--- a/text_recognizer/data/transforms.py
+++ b/text_recognizer/data/transforms.py
@@ -13,7 +13,7 @@ class WordPiece:
def __init__(
self,
num_features: int = 1000,
- tokens: str = "iamdb_1kwp_tokens_1000.txt" ,
+ tokens: str = "iamdb_1kwp_tokens_1000.txt",
lexicon: str = "iamdb_1kwp_lex_1000.txt",
data_dir: Optional[Union[str, Path]] = None,
use_words: bool = False,
@@ -35,4 +35,4 @@ class WordPiece:
self.max_len = max_len
def __call__(self, x: Tensor) -> Tensor:
- return self.mapping.emnist_to_wordpiece_indices(x)[:self.max_len]
+ return self.mapping.emnist_to_wordpiece_indices(x)[: self.max_len]
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 7dc1352..8dd4db2 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -39,7 +39,7 @@ class LitTransformerModel(LitBaseModel):
def configure_mapping(mapping: Optional[List[str]]) -> Tuple[List[str], List[int]]:
"""Configure mapping."""
# TODO: Fix me!!!
- mapping, inverse_mapping, _ = emnist_mapping()
+ mapping, inverse_mapping, _ = emnist_mapping(["\n"])
start_index = inverse_mapping["<s>"]
end_index = inverse_mapping["<e>"]
pad_index = inverse_mapping["<p>"]
diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py
index 41fd43f..63b43b2 100644
--- a/text_recognizer/networks/__init__.py
+++ b/text_recognizer/networks/__init__.py
@@ -1,2 +1,4 @@
"""Network modules"""
+from .backbones import EfficientNet
from .vqvae import VQVAE
+from .cnn_transformer import CNNTransformer
diff --git a/text_recognizer/networks/backbones/__init__.py b/text_recognizer/networks/backbones/__init__.py
new file mode 100644
index 0000000..25aed0e
--- /dev/null
+++ b/text_recognizer/networks/backbones/__init__.py
@@ -0,0 +1,2 @@
+"""Vision backbones."""
+from .efficientnet import EfficientNet
diff --git a/text_recognizer/networks/backbones/efficientnet.py b/text_recognizer/networks/backbones/efficientnet.py
new file mode 100644
index 0000000..61dea77
--- /dev/null
+++ b/text_recognizer/networks/backbones/efficientnet.py
@@ -0,0 +1,145 @@
+"""Efficient net b0 implementation."""
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class ConvNorm(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int,
+ padding: int,
+ groups: int = 1,
+ ) -> None:
+ super().__init__()
+ self.block = nn.Sequential(
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ groups=groups,
+ bias=False,
+ ),
+ nn.BatchNorm2d(num_features=out_channels),
+ nn.SiLU(inplace=True),
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.block(x)
+
+
+class SqueezeExcite(nn.Module):
+ def __init__(self, in_channels: int, reduce_dim: int) -> None:
+ super().__init__()
+ self.se = nn.Sequential(
+ nn.AdaptiveAvgPool2d(1), # [C, H, W] -> [C, 1, 1]
+ nn.Conv2d(in_channels=in_channels, out_channels=reduce_dim, kernel_size=1),
+ nn.SiLU(),
+ nn.Conv2d(in_channels=reduce_dim, out_channels=in_channels, kernel_size=1),
+ nn.Sigmoid(),
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x * self.se(x)
+
+
+class InvertedResidulaBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int,
+ padding: int,
+ expand_ratio: float,
+ reduction: int = 4,
+ survival_prob: float = 0.8,
+ ) -> None:
+ super().__init__()
+ self.survival_prob = survival_prob
+ self.use_residual = in_channels == out_channels and stride == 1
+ hidden_dim = in_channels * expand_ratio
+ self.expand = in_channels != hidden_dim
+ reduce_dim = in_channels // reduction
+
+ if self.expand:
+ self.expand_conv = ConvNorm(
+ in_channels, hidden_dim, kernel_size=3, stride=1, padding=1
+ )
+
+ self.conv = nn.Sequential(
+ ConvNorm(
+ hidden_dim, hidden_dim, kernel_size, stride, padding, groups=hidden_dim
+ ),
+ SqueezeExcite(hidden_dim, reduce_dim),
+ nn.Conv2d(
+ in_channels=hidden_dim,
+ out_channels=out_channels,
+ kernel_size=1,
+ bias=False,
+ ),
+ nn.BatchNorm2d(num_features=out_channels),
+ )
+
+ def stochastic_depth(self, x: Tensor) -> Tensor:
+ if not self.training:
+ return x
+
+ binary_tensor = (
+ torch.rand(x.shape[0], 1, 1, 1, device=x.device) < self.survival_prob
+ )
+ return torch.div(x, self.survival_prob) * binary_tensor
+
+ def forward(self, x: Tensor) -> Tensor:
+ out = self.expand_conv(x) if self.expand else x
+ if self.use_residual:
+ return self.stochastic_depth(self.conv(out)) + x
+ return self.conv(out)
+
+
+class EfficientNet(nn.Module):
+ """Efficient net b0 backbone."""
+
+ def __init__(self) -> None:
+ super().__init__()
+ self.base_model = [
+ # expand_ratio, channels, repeats, stride, kernel_size
+ [1, 16, 1, 1, 3],
+ [6, 24, 2, 2, 3],
+ [6, 40, 2, 2, 5],
+ [6, 80, 3, 2, 3],
+ [6, 112, 3, 1, 5],
+ [6, 192, 4, 2, 5],
+ [6, 320, 1, 1, 3],
+ ]
+
+ self.backbone = self._build_b0()
+
+ def _build_b0(self) -> nn.Sequential:
+ in_channels = 32
+ layers = [ConvNorm(1, in_channels, 3, stride=2, padding=1)]
+
+ for expand_ratio, out_channels, repeats, stride, kernel_size in self.base_model:
+ for i in range(repeats):
+ layers.append(
+ InvertedResidulaBlock(
+ in_channels,
+ out_channels,
+ expand_ratio=expand_ratio,
+ stride=stride if i == 0 else 1,
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ )
+ )
+ in_channels = out_channels
+ layers.append(ConvNorm(in_channels, 256, kernel_size=1, stride=1, padding=0))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.backbone(x)
diff --git a/text_recognizer/networks/cnn_transformer.py b/text_recognizer/networks/cnn_transformer.py
index e23a15d..d42c29d 100644
--- a/text_recognizer/networks/cnn_transformer.py
+++ b/text_recognizer/networks/cnn_transformer.py
@@ -33,8 +33,8 @@ NUM_WORD_PIECES = 1000
class CNNTransformer(nn.Module):
def __init__(
self,
- input_shape: Sequence[int],
- output_shape: Sequence[int],
+ input_dim: Sequence[int],
+ output_dims: Sequence[int],
encoder: Union[DictConfig, Dict],
vocab_size: Optional[int] = None,
num_decoder_layers: int = 4,
@@ -43,22 +43,29 @@ class CNNTransformer(nn.Module):
expansion_dim: int = 1024,
dropout_rate: float = 0.1,
transformer_activation: str = "glu",
+ *args,
+ **kwargs,
) -> None:
+ super().__init__()
self.vocab_size = (
NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size
)
+ self.pad_index = 3 # TODO: fix me
self.hidden_dim = hidden_dim
- self.max_output_length = output_shape[0]
+ self.max_output_length = output_dims[0]
# Image backbone
self.encoder = self._configure_encoder(encoder)
+ self.encoder_proj = nn.Conv2d(256, hidden_dim, kernel_size=1)
self.feature_map_encoding = PositionalEncoding2D(
- hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2]
+ hidden_dim=hidden_dim, max_h=input_dim[1], max_w=input_dim[2]
)
# Target token embedding
self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim)
- self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+ self.trg_position_encoding = PositionalEncoding(
+ hidden_dim, dropout_rate, max_len=output_dims[0]
+ )
# Transformer decoder
self.decoder = Decoder(
@@ -86,24 +93,25 @@ class CNNTransformer(nn.Module):
self.head.weight.data.uniform_(-0.1, 0.1)
nn.init.kaiming_normal_(
- self.feature_map_encoding.weight.data,
+ self.encoder_proj.weight.data,
a=0,
mode="fan_out",
nonlinearity="relu",
)
- if self.feature_map_encoding.bias is not None:
+ if self.encoder_proj.bias is not None:
_, fan_out = nn.init._calculate_fan_in_and_fan_out(
- self.feature_map_encoding.weight.data
+ self.encoder_proj.weight.data
)
bound = 1 / math.sqrt(fan_out)
- nn.init.normal_(self.feature_map_encoding.bias, -bound, bound)
+ nn.init.normal_(self.encoder_proj.bias, -bound, bound)
@staticmethod
def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]:
encoder = OmegaConf.create(encoder)
+ args = encoder.args or {}
network_module = importlib.import_module("text_recognizer.networks")
encoder_class = getattr(network_module, encoder.type)
- return encoder_class(**encoder.args)
+ return encoder_class(**args)
def encode(self, image: Tensor) -> Tensor:
"""Extracts image features with backbone.
@@ -121,6 +129,7 @@ class CNNTransformer(nn.Module):
"""
# Extract image features.
image_features = self.encoder(image)
+ image_features = self.encoder_proj(image_features)
# Add 2d encoding to the feature maps.
image_features = self.feature_map_encoding(image_features)
@@ -133,11 +142,19 @@ class CNNTransformer(nn.Module):
"""Decodes image features with transformer decoder."""
trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index)
trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim)
+ trg = rearrange(trg, "b t d -> t b d")
trg = self.trg_position_encoding(trg)
+ trg = rearrange(trg, "t b d -> b t d")
out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None)
logits = self.head(out)
return logits
+ def forward(self, image: Tensor, trg: Tensor) -> Tensor:
+ image_features = self.encode(image)
+ output = self.decode(image_features, trg)
+ output = rearrange(output, "b t c -> b c t")
+ return output
+
def predict(self, image: Tensor) -> Tensor:
"""Transcribes text in image(s)."""
bsz = image.shape[0]
diff --git a/text_recognizer/networks/coat/__init__.py b/text_recognizer/networks/coat/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/text_recognizer/networks/coat/__init__.py
+++ /dev/null
diff --git a/text_recognizer/networks/coat/factor_attention.py b/text_recognizer/networks/coat/factor_attention.py
deleted file mode 100644
index f91c5ef..0000000
--- a/text_recognizer/networks/coat/factor_attention.py
+++ /dev/null
@@ -1,9 +0,0 @@
-"""Factorized attention with convolutional relative positional encodings."""
-from torch import nn
-
-
-class FactorAttention(nn.Module):
- """Factorized attention with relative positional encodings."""
- def __init__(self, dim: int, num_heads: int) -> None:
- pass
-
diff --git a/text_recognizer/networks/coat/patch_embedding.py b/text_recognizer/networks/coat/patch_embedding.py
deleted file mode 100644
index 3b7b76a..0000000
--- a/text_recognizer/networks/coat/patch_embedding.py
+++ /dev/null
@@ -1,38 +0,0 @@
-"""Patch embedding for images and feature maps."""
-from typing import Sequence, Tuple
-
-from einops import rearrange
-from loguru import logger
-from torch import nn
-from torch import Tensor
-
-
-class PatchEmbedding(nn.Module):
- """Patch embedding of images."""
-
- def __init__(
- self,
- image_shape: Sequence[int],
- patch_size: int = 16,
- in_channels: int = 1,
- embedding_dim: int = 512,
- ) -> None:
- if image_shape[0] % patch_size == 0 and image_shape[1] % patch_size == 0:
- logger.error(
- f"Image shape {image_shape} not divisable by patch size {patch_size}"
- )
-
- self.patch_size = patch_size
- self.embedding = nn.Conv2d(
- in_channels, embedding_dim, kernel_size=patch_size, stride=patch_size
- )
- self.norm = nn.LayerNorm(embedding_dim)
-
- def forward(self, x: Tensor) -> Tuple[Tensor, Tuple[int, int]]:
- """Embeds image or feature maps with patch embedding."""
- _, _, h, w = x.shape
- h_out, w_out = h // self.patch_size, w // self.patch_size
- x = self.embedding(x)
- x = rearrange(x, "b c h w -> b (h w) c")
- x = self.norm(x)
- return x, (h_out, w_out)
diff --git a/text_recognizer/networks/coat/positional_encodings.py b/text_recognizer/networks/coat/positional_encodings.py
deleted file mode 100644
index 925db04..0000000
--- a/text_recognizer/networks/coat/positional_encodings.py
+++ /dev/null
@@ -1,76 +0,0 @@
-"""Positional encodings for input sequence to transformer."""
-from typing import Dict, Union, Tuple
-
-from einops import rearrange
-from loguru import logger
-import torch
-from torch import nn
-from torch import Tensor
-
-
-class RelativeEncoding(nn.Module):
- """Relative positional encoding."""
- def __init__(self, channels: int, heads: int, windows: Union[int, Dict[int, int]]) -> None:
- super().__init__()
- self.windows = {windows: heads} if isinstance(windows, int) else windows
- self.heads = list(self.windows.values())
- self.channel_heads = [head * channels for head in self.heads]
- self.convs = nn.ModuleList([
- nn.Conv2d(in_channels=head * channels,
- out_channels=head * channels,
- kernel_shape=window,
- padding=window // 2,
- dilation=1,
- groups=head * channels,
- ) for window, head in self.windows.items()])
-
- def forward(self, q: Tensor, v: Tensor, shape: Tuple[int, int]) -> Tensor:
- """Applies relative positional encoding."""
- b, heads, hw, c = q.shape
- h, w = shape
- if hw != h * w:
- logger.exception(f"Query width {hw} neq to height x width {h * w}")
- raise ValueError
-
- v = rearrange(v, "b heads (h w) c -> b (heads c) h w", h=h, w=w)
- v = torch.split(v, self.channel_heads, dim=1)
- v = [conv(x) for conv, x in zip(self.convs, v)]
- v = torch.cat(v, dim=1)
- v = rearrange(v, "b (heads c) h w -> b heads (h w) c", heads=heads)
-
- encoding = q * v
- zeros = torch.zeros((b, heads, 1, c), dtype=q.dtype, layout=q.layout, device=q.device)
- encoding = torch.cat((zeros, encoding), dim=2)
- return encoding
-
-
-class PositionalEncoding(nn.Module):
- """Convolutional positional encoding."""
- def __init__(self, dim: int, k: int = 3) -> None:
- super().__init__()
- self.encode = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=k, stride=1, padding=k//2, groups=dim)
-
- def forward(self, x: Tensor, shape: Tuple[int, int]) -> Tensor:
- """Applies convolutional encoding."""
- _, hw, _ = x.shape
- h, w = shape
-
- if hw != h * w:
- logger.exception(f"Query width {hw} neq to height x width {h * w}")
- raise ValueError
-
- # Depthwise convolution.
- x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
- x = self.encode(x) + x
- x = rearrange(x, "b c h w -> b (h w) c")
- return x
-
-
-
-
-
-
-
-
-
-
diff --git a/text_recognizer/networks/residual_network.py b/text_recognizer/networks/residual_network.py
index da7553d..c33f419 100644
--- a/text_recognizer/networks/residual_network.py
+++ b/text_recognizer/networks/residual_network.py
@@ -20,11 +20,7 @@ class Conv2dAuto(nn.Conv2d):
def conv_bn(in_channels: int, out_channels: int, *args, **kwargs) -> nn.Sequential:
"""3x3 convolution with batch norm."""
- conv3x3 = partial(
- Conv2dAuto,
- kernel_size=3,
- bias=False,
- )
+ conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,)
return nn.Sequential(
conv3x3(in_channels, out_channels, *args, **kwargs),
nn.BatchNorm2d(out_channels),
diff --git a/text_recognizer/networks/transducer/transducer.py b/text_recognizer/networks/transducer/transducer.py
index b10f93a..d7e3d08 100644
--- a/text_recognizer/networks/transducer/transducer.py
+++ b/text_recognizer/networks/transducer/transducer.py
@@ -392,12 +392,7 @@ def load_transducer_loss(
transitions = gtn.load(str(processed_path / transitions))
preprocessor = Preprocessor(
- data_dir,
- num_features,
- tokens_path,
- lexicon_path,
- use_words,
- prepend_wordsep,
+ data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep,
)
num_tokens = preprocessor.num_tokens
diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py
index 5874e97..c50afc3 100644
--- a/text_recognizer/networks/transformer/positional_encoding.py
+++ b/text_recognizer/networks/transformer/positional_encoding.py
@@ -33,7 +33,10 @@ class PositionalEncoding(nn.Module):
def forward(self, x: Tensor) -> Tensor:
"""Encodes the tensor with a postional embedding."""
- x = x + self.pe[:, : x.shape[1]]
+ # [T, B, D]
+ if x.shape[2] != self.pe.shape[2]:
+ raise ValueError(f"x shape does not match pe in the 3rd dim.")
+ x = x + self.pe[: x.shape[0]]
return self.dropout(x)
@@ -48,6 +51,7 @@ class PositionalEncoding2D(nn.Module):
pe = self.make_pe(hidden_dim, max_h, max_w)
self.register_buffer("pe", pe)
+ @staticmethod
def make_pe(hidden_dim: int, max_h: int, max_w: int) -> Tensor:
"""Returns 2d postional encoding."""
pe_h = PositionalEncoding.make_pe(
diff --git a/text_recognizer/networks/transformer/rotary_embedding.py b/text_recognizer/networks/transformer/rotary_embedding.py
new file mode 100644
index 0000000..5e80572
--- /dev/null
+++ b/text_recognizer/networks/transformer/rotary_embedding.py
@@ -0,0 +1,39 @@
+"""Roatary embedding.
+
+Stolen from lucidrains:
+ https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py
+
+Explanation of roatary:
+ https://blog.eleuther.ai/rotary-embeddings/
+
+"""
+from typing import Tuple
+
+from einops import rearrange
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class RotaryEmbedding(nn.Module):
+ def __init__(self, dim: int):
+ super().__init__()
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer("inv_freq", inv_freq)
+
+ def forward(self, x: Tensor, seq_dim: int = 1) -> Tensor:
+ t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
+ freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ return emb[None, :, :]
+
+
+def rotate_half(x: Tensor) -> Tensor:
+ x = rearrange(x, "... (j d) -> ... j d", j=2)
+ x1, x2 = x.unbind(dim=-2)
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q: Tensor, k: Tensor, freqs: Tensor) -> Tuple[Tensor, Tensor]:
+ q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k))
+ return q, k
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
index 93a1e43..32de912 100644
--- a/text_recognizer/networks/vqvae/decoder.py
+++ b/text_recognizer/networks/vqvae/decoder.py
@@ -44,12 +44,7 @@ class Decoder(nn.Module):
# Configure encoder.
self.decoder = self._build_decoder(
- channels,
- kernel_sizes,
- strides,
- num_residual_layers,
- activation,
- dropout,
+ channels, kernel_sizes, strides, num_residual_layers, activation, dropout,
)
def _build_decompression_block(
@@ -78,9 +73,7 @@ class Decoder(nn.Module):
)
if self.upsampling and i < len(self.upsampling):
- modules.append(
- nn.Upsample(size=self.upsampling[i]),
- )
+ modules.append(nn.Upsample(size=self.upsampling[i]),)
if dropout is not None:
modules.append(dropout)
@@ -109,12 +102,7 @@ class Decoder(nn.Module):
) -> nn.Sequential:
self.res_block.append(
- nn.Conv2d(
- self.embedding_dim,
- channels[0],
- kernel_size=1,
- stride=1,
- )
+ nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,)
)
# Bottleneck module.
diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py
index b0cceed..65801df 100644
--- a/text_recognizer/networks/vqvae/encoder.py
+++ b/text_recognizer/networks/vqvae/encoder.py
@@ -11,10 +11,7 @@ from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer
class _ResidualBlock(nn.Module):
def __init__(
- self,
- in_channels: int,
- out_channels: int,
- dropout: Optional[Type[nn.Module]],
+ self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]],
) -> None:
super().__init__()
self.block = [
@@ -138,12 +135,7 @@ class Encoder(nn.Module):
)
encoder.append(
- nn.Conv2d(
- channels[-1],
- self.embedding_dim,
- kernel_size=1,
- stride=1,
- )
+ nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,)
)
return nn.Sequential(*encoder)