author     Gustaf Rydholm <gustaf.rydholm@gmail.com>   2021-04-22 08:15:58 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>   2021-04-22 08:15:58 +0200
commit     1ca8b0b9e0613c1e02f6a5d8b49e20c4d6916412 (patch)
tree       5e610ac459c9b254f8826e92372346f01f8e2412 /text_recognizer
parent     ffa4be4bf4e3758e01d52a9c1f354a05a90b93de (diff)
Fixed training script, able to train vqvae
Diffstat (limited to 'text_recognizer')
-rw-r--r--  text_recognizer/data/__init__.py                   |   3
-rw-r--r--  text_recognizer/data/emnist_lines.py               |   2
-rw-r--r--  text_recognizer/data/iam_extended_paragraphs.py    |  15
-rw-r--r--  text_recognizer/data/iam_paragraphs.py             |  23
-rw-r--r--  text_recognizer/data/iam_preprocessor.py           |   1
-rw-r--r--  text_recognizer/data/iam_synthetic_paragraphs.py   |   7
-rw-r--r--  text_recognizer/data/mappings.py                   |  16
-rw-r--r--  text_recognizer/data/transforms.py                 |  14
-rw-r--r--  text_recognizer/models/__init__.py                 |   3
-rw-r--r--  text_recognizer/models/base.py                     |   9
-rw-r--r--  text_recognizer/models/vqvae.py                    |  70
-rw-r--r--  text_recognizer/networks/__init__.py               |   2
-rw-r--r--  text_recognizer/networks/cnn_transformer.py        | 257
-rw-r--r--  text_recognizer/networks/image_transformer.py      | 165
-rw-r--r--  text_recognizer/networks/residual_network.py       |   6
-rw-r--r--  text_recognizer/networks/transducer/transducer.py  |   7
-rw-r--r--  text_recognizer/networks/vqvae/decoder.py          |  20
-rw-r--r--  text_recognizer/networks/vqvae/encoder.py          |  30
-rw-r--r--  text_recognizer/networks/vqvae/vqvae.py            |   5
19 files changed, 318 insertions(+), 337 deletions(-)
diff --git a/text_recognizer/data/__init__.py b/text_recognizer/data/__init__.py
index 9a42fa9..3599a8b 100644
--- a/text_recognizer/data/__init__.py
+++ b/text_recognizer/data/__init__.py
@@ -2,3 +2,6 @@
from .base_dataset import BaseDataset, convert_strings_to_labels, split_dataset
from .base_data_module import BaseDataModule, load_and_print_info
from .download_utils import download_dataset
+from .iam_paragraphs import IAMParagraphs
+from .iam_synthetic_paragraphs import IAMSyntheticParagraphs
+from .iam_extended_paragraphs import IAMExtendedParagraphs
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 72665d0..9650198 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -57,8 +57,8 @@ class EMNISTLines(BaseDataModule):
self.num_test = num_test
self.emnist = EMNIST()
- # TODO: fix mapping
self.mapping = self.emnist.mapping
+
max_width = (
int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap))
+ IMAGE_X_PADDING
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index d2529b4..2380660 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -10,18 +10,27 @@ from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs
class IAMExtendedParagraphs(BaseDataModule):
def __init__(
self,
- batch_size: int = 128,
+ batch_size: int = 16,
num_workers: int = 0,
train_fraction: float = 0.8,
augment: bool = True,
+ word_pieces: bool = False,
) -> None:
super().__init__(batch_size, num_workers)
self.iam_paragraphs = IAMParagraphs(
- batch_size, num_workers, train_fraction, augment,
+ batch_size,
+ num_workers,
+ train_fraction,
+ augment,
+ word_pieces,
)
self.iam_synthetic_paragraphs = IAMSyntheticParagraphs(
- batch_size, num_workers, train_fraction, augment,
+ batch_size,
+ num_workers,
+ train_fraction,
+ augment,
+ word_pieces,
)
self.dims = self.iam_paragraphs.dims
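
Usage sketch (not part of this commit): the new word_pieces flag is forwarded unchanged to
both IAMParagraphs and IAMSyntheticParagraphs. Assuming the usual LightningDataModule
interface provided by BaseDataModule, and that the word-piece token/lexicon files are in
place, the extended data module might be driven like this:

    from text_recognizer.data import IAMExtendedParagraphs

    # word_pieces=True makes the datasets emit word-piece indices as targets
    # instead of raw character indices; batch_size now defaults to 16.
    datamodule = IAMExtendedParagraphs(
        batch_size=16,
        num_workers=4,
        train_fraction=0.8,
        augment=True,
        word_pieces=True,
    )
    datamodule.prepare_data()
    datamodule.setup()
    train_loader = datamodule.train_dataloader()
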
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index f588587..62c44f9 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -5,8 +5,7 @@ from typing import Dict, List, Optional, Sequence, Tuple
from loguru import logger
import numpy as np
-from PIL import Image, ImageFile, ImageOps
-import torch
+from PIL import Image, ImageOps
import torchvision.transforms as transforms
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm
@@ -19,6 +18,7 @@ from text_recognizer.data.base_dataset import (
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.emnist import emnist_mapping
from text_recognizer.data.iam import IAM
+from text_recognizer.data.transforms import WordPiece
PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_paragraphs"
@@ -37,15 +37,15 @@ class IAMParagraphs(BaseDataModule):
def __init__(
self,
- batch_size: int = 128,
+ batch_size: int = 16,
num_workers: int = 0,
train_fraction: float = 0.8,
augment: bool = True,
+ word_pieces: bool = False,
) -> None:
super().__init__(batch_size, num_workers)
- # TODO: pass in transform and target transform
- # TODO: pass in mapping
self.augment = augment
+ self.word_pieces = word_pieces
self.mapping, self.inverse_mapping, _ = emnist_mapping(
extra_symbols=[NEW_LINE_TOKEN]
)
@@ -101,6 +101,7 @@ class IAMParagraphs(BaseDataModule):
data,
targets,
transform=get_transform(image_shape=self.dims[1:], augment=augment),
+ target_transform=get_target_transform(self.word_pieces)
)
logger.info(f"Loading IAM paragraph regions and lines for {stage}...")
@@ -161,7 +162,10 @@ def get_dataset_properties() -> Dict:
"min": min(_get_property_values("num_lines")),
"max": max(_get_property_values("num_lines")),
},
- "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0),},
+ "crop_shape": {
+ "min": crop_shapes.min(axis=0),
+ "max": crop_shapes.max(axis=0),
+ },
"aspect_ratio": {
"min": aspect_ratio.min(axis=0),
"max": aspect_ratio.max(axis=0),
@@ -282,7 +286,9 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Com
),
transforms.ColorJitter(brightness=(0.8, 1.6)),
transforms.RandomAffine(
- degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR,
+ degrees=1,
+ shear=(-10, 10),
+ interpolation=InterpolationMode.BILINEAR,
),
]
else:
@@ -290,6 +296,9 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Com
transforms_list.append(transforms.ToTensor())
return transforms.Compose(transforms_list)
+def get_target_transform(word_pieces: bool) -> Optional[transforms.Compose]:
+ """Transform emnist characters to word pieces."""
+ return transforms.Compose([WordPiece()]) if word_pieces else None
def _labels_filename(split: str) -> Path:
"""Return filename of processed labels."""
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py
index 60f8a9f..b5f72da 100644
--- a/text_recognizer/data/iam_preprocessor.py
+++ b/text_recognizer/data/iam_preprocessor.py
@@ -89,6 +89,7 @@ class Preprocessor:
self.lexicon = None
if self.special_tokens is not None:
+ self.special_tokens += ("#", "*")
self.tokens += self.special_tokens
self.graphemes += self.special_tokens
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index 9f1bd12..4ccc5c2 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -18,6 +18,7 @@ from text_recognizer.data.base_data_module import BaseDataModule, load_and_print
from text_recognizer.data.iam_paragraphs import (
get_dataset_properties,
get_transform,
+ get_target_transform,
NEW_LINE_TOKEN,
IAMParagraphs,
IMAGE_SCALE_FACTOR,
@@ -41,12 +42,13 @@ class IAMSyntheticParagraphs(IAMParagraphs):
def __init__(
self,
- batch_size: int = 128,
+ batch_size: int = 16,
num_workers: int = 0,
train_fraction: float = 0.8,
augment: bool = True,
+ word_pieces: bool = False,
) -> None:
- super().__init__(batch_size, num_workers, train_fraction, augment)
+ super().__init__(batch_size, num_workers, train_fraction, augment, word_pieces)
def prepare_data(self) -> None:
"""Prepare IAM lines to be used to generate paragraphs."""
@@ -95,6 +97,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
transform=get_transform(
image_shape=self.dims[1:], augment=self.augment
),
+ target_transform=get_target_transform(self.word_pieces)
)
def __repr__(self) -> str:
diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py
index cfa0ec7..f4016ba 100644
--- a/text_recognizer/data/mappings.py
+++ b/text_recognizer/data/mappings.py
@@ -8,7 +8,7 @@ import torch
from torch import Tensor
from text_recognizer.data.emnist import emnist_mapping
-from text_recognizer.datasets.iam_preprocessor import Preprocessor
+from text_recognizer.data.iam_preprocessor import Preprocessor
class AbstractMapping(ABC):
@@ -57,14 +57,14 @@ class EmnistMapping(AbstractMapping):
class WordPieceMapping(EmnistMapping):
def __init__(
self,
- num_features: int,
- tokens: str,
- lexicon: str,
+ num_features: int = 1000,
+        tokens: str = "iamdb_1kwp_tokens_1000.txt",
+ lexicon: str = "iamdb_1kwp_lex_1000.txt",
data_dir: Optional[Union[str, Path]] = None,
use_words: bool = False,
prepend_wordsep: bool = False,
special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
- extra_symbols: Optional[Sequence[str]] = None,
+ extra_symbols: Optional[Sequence[str]] = ("\n", ),
) -> None:
super().__init__(extra_symbols)
self.wordpiece_processor = self._configure_wordpiece_processor(
@@ -78,8 +78,8 @@ class WordPieceMapping(EmnistMapping):
extra_symbols,
)
+ @staticmethod
def _configure_wordpiece_processor(
- self,
num_features: int,
tokens: str,
lexicon: str,
@@ -90,7 +90,7 @@ class WordPieceMapping(EmnistMapping):
extra_symbols: Optional[Sequence[str]],
) -> Preprocessor:
data_dir = (
- (Path(__file__).resolve().parents[2] / "data" / "raw" / "iam" / "iamdb")
+ (Path(__file__).resolve().parents[2] / "data" / "downloaded" / "iam" / "iamdb")
if data_dir is None
else Path(data_dir)
)
@@ -138,6 +138,6 @@ class WordPieceMapping(EmnistMapping):
return self.wordpiece_processor.to_index(text)
def emnist_to_wordpiece_indices(self, x: Tensor) -> Tensor:
- text = self.mapping.get_text(x)
+ text = "".join([self.mapping[i] for i in x])
text = text.lower().replace(" ", "▁")
return torch.LongTensor(self.wordpiece_processor.to_index(text))
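
Illustrative sketch (not from this commit): emnist_to_wordpiece_indices now joins the
EMNIST characters itself instead of calling a non-existent get_text helper, lower-cases
the text, swaps spaces for the "▁" word separator, and hands the string to the word-piece
Preprocessor. With the new defaults, the mapping can be built without arguments, provided
the 1k word-piece token and lexicon files sit under the project's data directory:

    import torch

    from text_recognizer.data.mappings import WordPieceMapping

    mapping = WordPieceMapping()  # iamdb_1kwp_{tokens,lex}_1000.txt, "\n" extra symbol

    # Values below are purely illustrative EMNIST character indices.
    emnist_indices = torch.LongTensor([20, 5, 12, 12, 15])
    wordpiece_indices = mapping.emnist_to_wordpiece_indices(emnist_indices)
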
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
index f53df64..8d1bedd 100644
--- a/text_recognizer/data/transforms.py
+++ b/text_recognizer/data/transforms.py
@@ -4,7 +4,7 @@ from typing import Optional, Union, Sequence
from torch import Tensor
-from text_recognizer.datasets.mappings import WordPieceMapping
+from text_recognizer.data.mappings import WordPieceMapping
class WordPiece:
@@ -12,14 +12,15 @@ class WordPiece:
def __init__(
self,
- num_features: int,
- tokens: str,
- lexicon: str,
+ num_features: int = 1000,
+        tokens: str = "iamdb_1kwp_tokens_1000.txt",
+ lexicon: str = "iamdb_1kwp_lex_1000.txt",
data_dir: Optional[Union[str, Path]] = None,
use_words: bool = False,
prepend_wordsep: bool = False,
special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
- extra_symbols: Optional[Sequence[str]] = None,
+ extra_symbols: Optional[Sequence[str]] = ("\n",),
+ max_len: int = 192,
) -> None:
self.mapping = WordPieceMapping(
num_features,
@@ -31,6 +32,7 @@ class WordPiece:
special_tokens,
extra_symbols,
)
+ self.max_len = max_len
def __call__(self, x: Tensor) -> Tensor:
- return self.mapping.emnist_to_wordpiece_indices(x)
+ return self.mapping.emnist_to_wordpiece_indices(x)[:self.max_len]
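
Quick usage note (illustrative, not part of the commit): because WordPiece is a plain
callable, it composes directly as a torchvision target transform, which is exactly how
get_target_transform wires it into the IAM paragraph datasets; max_len=192 truncates long
paragraph targets to a fixed upper length:

    import torchvision.transforms as transforms

    from text_recognizer.data.transforms import WordPiece

    # Mirrors get_target_transform(word_pieces=True) from iam_paragraphs.py.
    target_transform = transforms.Compose([WordPiece(max_len=192)])
    # BaseDataset(data, targets, transform=..., target_transform=target_transform)
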
diff --git a/text_recognizer/models/__init__.py b/text_recognizer/models/__init__.py
index e69de29..5ac2510 100644
--- a/text_recognizer/models/__init__.py
+++ b/text_recognizer/models/__init__.py
@@ -0,0 +1,3 @@
+"""PyTorch Lightning models modules."""
+from .transformer import LitTransformerModel
+from .vqvae import LitVQVAEModel
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index aeda039..88ffde6 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -40,6 +40,15 @@ class LitBaseModel(pl.LightningModule):
args = {} or criterion.args
return getattr(nn, criterion.type)(**args)
+ def optimizer_zero_grad(
+ self,
+ epoch: int,
+ batch_idx: int,
+ optimizer: Type[torch.optim.Optimizer],
+ optimizer_idx: int,
+ ) -> None:
+ optimizer.zero_grad(set_to_none=True)
+
def _configure_optimizer(self) -> Type[torch.optim.Optimizer]:
"""Configures the optimizer."""
args = {} or self._optimizer.args
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
new file mode 100644
index 0000000..ef2213c
--- /dev/null
+++ b/text_recognizer/models/vqvae.py
@@ -0,0 +1,70 @@
+"""PyTorch Lightning model for base Transformers."""
+from typing import Any, Dict, Union, Tuple, Type
+
+from omegaconf import DictConfig, OmegaConf
+from torch import nn
+from torch import Tensor
+import torch.nn.functional as F
+import wandb
+
+from text_recognizer.models.base import LitBaseModel
+
+
+class LitVQVAEModel(LitBaseModel):
+    """A PyTorch Lightning model for the VQ-VAE network."""
+
+ def __init__(
+ self,
+ network: Type[nn.Module],
+ optimizer: Union[DictConfig, Dict],
+ lr_scheduler: Union[DictConfig, Dict],
+ criterion: Union[DictConfig, Dict],
+ monitor: str = "val_loss",
+ *args: Any,
+ **kwargs: Dict,
+ ) -> None:
+ super().__init__(network, optimizer, lr_scheduler, criterion, monitor)
+
+ def forward(self, data: Tensor) -> Tensor:
+        """Forward pass with the VQ-VAE network."""
+ return self.network.predict(data)
+
+ def _log_prediction(self, data: Tensor, reconstructions: Tensor) -> None:
+ """Logs prediction on image with wandb."""
+ try:
+ self.logger.experiment.log(
+ {
+ "val_pred_examples": [
+ wandb.Image(data[0]),
+ wandb.Image(reconstructions[0]),
+ ]
+ }
+ )
+ except AttributeError:
+ pass
+
+ def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
+ """Training step."""
+ data, _ = batch
+ reconstructions, vq_loss = self.network(data)
+ loss = self.loss_fn(reconstructions, data)
+ loss += vq_loss
+ self.log("train_loss", loss)
+ return loss
+
+ def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
+ """Validation step."""
+ data, _ = batch
+ reconstructions, vq_loss = self.network(data)
+ loss = self.loss_fn(reconstructions, data)
+ loss += vq_loss
+ self.log("val_loss", loss, prog_bar=True)
+ self._log_prediction(data, reconstructions)
+
+ def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
+ """Test step."""
+ data, _ = batch
+ reconstructions, vq_loss = self.network(data)
+ loss = self.loss_fn(reconstructions, data)
+        loss += vq_loss
+        self.log("test_loss", loss)
+        self._log_prediction(data, reconstructions)
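
For reference, the objective optimized by all three steps above is the reconstruction
loss plus the vector-quantization loss returned by the network. A minimal sketch,
assuming the configured criterion resolves to nn.MSELoss (the actual choice lives in the
training config, outside this diff):

    import torch.nn.functional as F
    from torch import Tensor, nn

    def vqvae_objective(network: nn.Module, images: Tensor) -> Tensor:
        """Reconstruction term plus the VQ commitment/codebook term."""
        reconstructions, vq_loss = network(images)  # VQVAE.forward returns both
        return F.mse_loss(reconstructions, images) + vq_loss
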
diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py
index 979149f..41fd43f 100644
--- a/text_recognizer/networks/__init__.py
+++ b/text_recognizer/networks/__init__.py
@@ -1,2 +1,2 @@
"""Network modules"""
-from .image_transformer import ImageTransformer
+from .vqvae import VQVAE
diff --git a/text_recognizer/networks/cnn_transformer.py b/text_recognizer/networks/cnn_transformer.py
index 9150b55..e23a15d 100644
--- a/text_recognizer/networks/cnn_transformer.py
+++ b/text_recognizer/networks/cnn_transformer.py
@@ -1,158 +1,165 @@
-"""A CNN-Transformer for image to text recognition."""
-from typing import Dict, Optional, Tuple
+"""A Transformer with a cnn backbone.
+
+The network encodes a image with a convolutional backbone to a latent representation,
+i.e. feature maps. A 2d positional encoding is applied to the feature maps for
+spatial information. The resulting feature are then set to a transformer decoder
+together with the target tokens.
+
+TODO: Local attention for the lower decoder layers.
+
+"""
+import importlib
+import math
+from typing import Dict, Optional, Union, Sequence, Type
from einops import rearrange
+from omegaconf import DictConfig, OmegaConf
import torch
from torch import nn
from torch import Tensor
-from text_recognizer.networks.transformer import PositionalEncoding, Transformer
-from text_recognizer.networks.util import activation_function
-from text_recognizer.networks.util import configure_backbone
+from text_recognizer.data.emnist import NUM_SPECIAL_TOKENS
+from text_recognizer.networks.transformer import (
+ Decoder,
+ DecoderLayer,
+ PositionalEncoding,
+ PositionalEncoding2D,
+ target_padding_mask,
+)
+NUM_WORD_PIECES = 1000
-class CNNTransformer(nn.Module):
- """CNN+Transfomer for image to sequence prediction."""
+class CNNTransformer(nn.Module):
def __init__(
self,
- num_encoder_layers: int,
- num_decoder_layers: int,
- hidden_dim: int,
- vocab_size: int,
- num_heads: int,
- adaptive_pool_dim: Tuple,
- expansion_dim: int,
- dropout_rate: float,
- trg_pad_index: int,
- max_len: int,
- backbone: str,
- backbone_args: Optional[Dict] = None,
- activation: str = "gelu",
- pool_kernel: Optional[Tuple[int, int]] = None,
+ input_shape: Sequence[int],
+ output_shape: Sequence[int],
+ encoder: Union[DictConfig, Dict],
+ vocab_size: Optional[int] = None,
+ num_decoder_layers: int = 4,
+ hidden_dim: int = 256,
+ num_heads: int = 4,
+ expansion_dim: int = 1024,
+ dropout_rate: float = 0.1,
+ transformer_activation: str = "glu",
) -> None:
        super().__init__()
- self.trg_pad_index = trg_pad_index
- self.vocab_size = vocab_size
- self.backbone = configure_backbone(backbone, backbone_args)
-
- if pool_kernel is not None:
- self.max_pool = nn.MaxPool2d(pool_kernel, stride=2)
- else:
- self.max_pool = None
-
- self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim)
-
- self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
- self.pos_dropout = nn.Dropout(p=dropout_rate)
- self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
-
- nn.init.normal_(self.character_embedding.weight, std=0.02)
-
- self.adaptive_pool = (
- nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None
+ self.vocab_size = (
+ NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size
)
+ self.hidden_dim = hidden_dim
+ self.max_output_length = output_shape[0]
- self.transformer = Transformer(
- num_encoder_layers,
- num_decoder_layers,
- hidden_dim,
- num_heads,
- expansion_dim,
- dropout_rate,
- activation,
+ # Image backbone
+ self.encoder = self._configure_encoder(encoder)
+ self.feature_map_encoding = PositionalEncoding2D(
+ hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2]
)
- self.head = nn.Sequential(
- # nn.Linear(hidden_dim, hidden_dim * 2),
- # activation_function(activation),
- nn.Linear(hidden_dim, vocab_size),
- )
+ # Target token embedding
+ self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim)
+ self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
- def _create_trg_mask(self, trg: Tensor) -> Tensor:
- # Move this outside the transformer.
- trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
- trg_len = trg.shape[1]
- trg_sub_mask = torch.tril(
- torch.ones((trg_len, trg_len), device=trg.device)
- ).bool()
- trg_mask = trg_pad_mask & trg_sub_mask
- return trg_mask
-
- def encoder(self, src: Tensor) -> Tensor:
- """Forward pass with the encoder of the transformer."""
- return self.transformer.encoder(src)
-
- def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
- """Forward pass with the decoder of the transformer + classification head."""
- return self.head(
- self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+ # Transformer decoder
+ self.decoder = Decoder(
+ decoder_layer=DecoderLayer(
+ hidden_dim=hidden_dim,
+ num_heads=num_heads,
+ expansion_dim=expansion_dim,
+ dropout_rate=dropout_rate,
+ activation=transformer_activation,
+ ),
+ num_layers=num_decoder_layers,
+ norm=nn.LayerNorm(hidden_dim),
)
- def extract_image_features(self, src: Tensor) -> Tensor:
- """Extracts image features with a backbone neural network.
-
- It seem like the winning idea was to swap channels and width dimension and collapse
- the height dimension. The transformer is learning like a baby with this implementation!!! :D
- Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D
+ # Classification head
+ self.head = nn.Linear(hidden_dim, self.vocab_size)
- Args:
- src (Tensor): Input tensor.
+ # Initialize weights
+ self._init_weights()
- Returns:
- Tensor: A input src to the transformer.
+ def _init_weights(self) -> None:
+ """Initialize network weights."""
+ self.trg_embedding.weight.data.uniform_(-0.1, 0.1)
+ self.head.bias.data.zero_()
+ self.head.weight.data.uniform_(-0.1, 0.1)
- """
- # If batch dimension is missing, it needs to be added.
- if len(src.shape) < 4:
- src = src[(None,) * (4 - len(src.shape))]
-
- src = self.backbone(src)
-
- if self.max_pool is not None:
- src = self.max_pool(src)
-
- if self.adaptive_pool is not None and len(src.shape) == 4:
- src = rearrange(src, "b c h w -> b w c h")
- src = self.adaptive_pool(src)
- src = src.squeeze(3)
- elif len(src.shape) == 4:
- src = rearrange(src, "b c h w -> b (h w) c")
+ nn.init.kaiming_normal_(
+ self.feature_map_encoding.weight.data,
+ a=0,
+ mode="fan_out",
+ nonlinearity="relu",
+ )
+ if self.feature_map_encoding.bias is not None:
+ _, fan_out = nn.init._calculate_fan_in_and_fan_out(
+ self.feature_map_encoding.weight.data
+ )
+ bound = 1 / math.sqrt(fan_out)
+ nn.init.normal_(self.feature_map_encoding.bias, -bound, bound)
+
+ @staticmethod
+ def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]:
+ encoder = OmegaConf.create(encoder)
+ network_module = importlib.import_module("text_recognizer.networks")
+ encoder_class = getattr(network_module, encoder.type)
+ return encoder_class(**encoder.args)
+
+ def encode(self, image: Tensor) -> Tensor:
+ """Extracts image features with backbone.
- b, t, _ = src.shape
+ Args:
+ image (Tensor): Image(s) of handwritten text.
- src += self.src_position_embedding[:, :t]
- src = self.pos_dropout(src)
+        Returns:
+ Tensor: Image features.
- return src
+ Shapes:
+ - image: :math: `(B, C, H, W)`
+ - latent: :math: `(B, T, C)`
- def target_embedding(self, trg: Tensor) -> Tuple[Tensor, Tensor]:
- """Encodes target tensor with embedding and postion.
+ """
+ # Extract image features.
+ image_features = self.encoder(image)
- Args:
- trg (Tensor): Target tensor.
+ # Add 2d encoding to the feature maps.
+ image_features = self.feature_map_encoding(image_features)
- Returns:
- Tuple[Tensor, Tensor]: Encoded target tensor and target mask.
+ # Collapse features maps height and width.
+ image_features = rearrange(image_features, "b c h w -> b (h w) c")
+ return image_features
- """
- trg = self.character_embedding(trg.long())
+ def decode(self, memory: Tensor, trg: Tensor) -> Tensor:
+ """Decodes image features with transformer decoder."""
+ trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index)
+ trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim)
trg = self.trg_position_encoding(trg)
- return trg
-
- def decode_image_features(
- self, image_features: Tensor, trg: Optional[Tensor] = None
- ) -> Tensor:
- """Takes images features from the backbone and decodes them with the transformer."""
- trg_mask = self._create_trg_mask(trg)
- trg = self.target_embedding(trg)
- out = self.transformer(image_features, trg, trg_mask=trg_mask)
-
+ out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None)
logits = self.head(out)
return logits
- def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
- """Forward pass with CNN transfomer."""
- image_features = self.extract_image_features(x)
- logits = self.decode_image_features(image_features, trg)
- return logits
+ def predict(self, image: Tensor) -> Tensor:
+ """Transcribes text in image(s)."""
+ bsz = image.shape[0]
+ image_features = self.encode(image)
+
+ output_tokens = (
+ (torch.ones((bsz, self.max_output_length)) * self.pad_index)
+ .type_as(image)
+ .long()
+ )
+ output_tokens[:, 0] = self.start_index
+ for i in range(1, self.max_output_length):
+ trg = output_tokens[:, :i]
+ output = self.decode(image_features, trg)
+ output = torch.argmax(output, dim=-1)
+            output_tokens[:, i] = output[:, -1]
+
+ # Set all tokens after end token to be padding.
+ for i in range(1, self.max_output_length):
+            indices = (output_tokens[:, i - 1] == self.end_index) | (
+                output_tokens[:, i - 1] == self.pad_index
+            )
+ output_tokens[indices, i] = self.pad_index
+
+ return output_tokens
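
Note on the new encoder wiring (example values are assumptions, not from this commit):
instead of a backbone name plus backbone_args, CNNTransformer now takes a config mapping
with "type" and "args" keys, which _configure_encoder resolves against
text_recognizer.networks via importlib. A hedged construction sketch, where the encoder
class name and every numeric value are placeholders:

    from text_recognizer.networks.cnn_transformer import CNNTransformer

    # "ResidualNetwork" must be exported from text_recognizer.networks for this
    # lookup to succeed; the args dict is purely illustrative.
    encoder_config = {
        "type": "ResidualNetwork",
        "args": {"in_channels": 1, "block_sizes": [64, 128], "depths": [2, 2]},
    }

    network = CNNTransformer(
        input_shape=(1, 576, 640),   # (C, H, W) of the input image
        output_shape=(451, 1),       # output_shape[0] = max output sequence length
        encoder=encoder_config,
        hidden_dim=256,
        num_decoder_layers=4,
    )
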
diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py
deleted file mode 100644
index a6aaca4..0000000
--- a/text_recognizer/networks/image_transformer.py
+++ /dev/null
@@ -1,165 +0,0 @@
-"""A Transformer with a cnn backbone.
-
-The network encodes a image with a convolutional backbone to a latent representation,
-i.e. feature maps. A 2d positional encoding is applied to the feature maps for
-spatial information. The resulting feature are then set to a transformer decoder
-together with the target tokens.
-
-TODO: Local attention for lower layer in attention.
-
-"""
-import importlib
-import math
-from typing import Dict, Optional, Union, Sequence, Type
-
-from einops import rearrange
-from omegaconf import DictConfig, OmegaConf
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.data.emnist import NUM_SPECIAL_TOKENS
-from text_recognizer.networks.transformer import (
- Decoder,
- DecoderLayer,
- PositionalEncoding,
- PositionalEncoding2D,
- target_padding_mask,
-)
-
-NUM_WORD_PIECES = 1000
-
-
-class ImageTransformer(nn.Module):
- def __init__(
- self,
- input_shape: Sequence[int],
- output_shape: Sequence[int],
- encoder: Union[DictConfig, Dict],
- vocab_size: Optional[int] = None,
- num_decoder_layers: int = 4,
- hidden_dim: int = 256,
- num_heads: int = 4,
- expansion_dim: int = 1024,
- dropout_rate: float = 0.1,
- transformer_activation: str = "glu",
- ) -> None:
- self.vocab_size = (
- NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size
- )
- self.hidden_dim = hidden_dim
- self.max_output_length = output_shape[0]
-
- # Image backbone
- self.encoder = self._configure_encoder(encoder)
- self.feature_map_encoding = PositionalEncoding2D(
- hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2]
- )
-
- # Target token embedding
- self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim)
- self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
-
- # Transformer decoder
- self.decoder = Decoder(
- decoder_layer=DecoderLayer(
- hidden_dim=hidden_dim,
- num_heads=num_heads,
- expansion_dim=expansion_dim,
- dropout_rate=dropout_rate,
- activation=transformer_activation,
- ),
- num_layers=num_decoder_layers,
- norm=nn.LayerNorm(hidden_dim),
- )
-
- # Classification head
- self.head = nn.Linear(hidden_dim, self.vocab_size)
-
- # Initialize weights
- self._init_weights()
-
- def _init_weights(self) -> None:
- """Initialize network weights."""
- self.trg_embedding.weight.data.uniform_(-0.1, 0.1)
- self.head.bias.data.zero_()
- self.head.weight.data.uniform_(-0.1, 0.1)
-
- nn.init.kaiming_normal_(
- self.feature_map_encoding.weight.data,
- a=0,
- mode="fan_out",
- nonlinearity="relu",
- )
- if self.feature_map_encoding.bias is not None:
- _, fan_out = nn.init._calculate_fan_in_and_fan_out(
- self.feature_map_encoding.weight.data
- )
- bound = 1 / math.sqrt(fan_out)
- nn.init.normal_(self.feature_map_encoding.bias, -bound, bound)
-
- @staticmethod
- def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]:
- encoder = OmegaConf.create(encoder)
- network_module = importlib.import_module("text_recognizer.networks")
- encoder_class = getattr(network_module, encoder.type)
- return encoder_class(**encoder.args)
-
- def encode(self, image: Tensor) -> Tensor:
- """Extracts image features with backbone.
-
- Args:
- image (Tensor): Image(s) of handwritten text.
-
- Retuns:
- Tensor: Image features.
-
- Shapes:
- - image: :math: `(B, C, H, W)`
- - latent: :math: `(B, T, C)`
-
- """
- # Extract image features.
- image_features = self.encoder(image)
-
- # Add 2d encoding to the feature maps.
- image_features = self.feature_map_encoding(image_features)
-
- # Collapse features maps height and width.
- image_features = rearrange(image_features, "b c h w -> b (h w) c")
- return image_features
-
- def decode(self, memory: Tensor, trg: Tensor) -> Tensor:
- """Decodes image features with transformer decoder."""
- trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index)
- trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim)
- trg = self.trg_position_encoding(trg)
- out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None)
- logits = self.head(out)
- return logits
-
- def predict(self, image: Tensor) -> Tensor:
- """Transcribes text in image(s)."""
- bsz = image.shape[0]
- image_features = self.encode(image)
-
- output_tokens = (
- (torch.ones((bsz, self.max_output_length)) * self.pad_index)
- .type_as(image)
- .long()
- )
- output_tokens[:, 0] = self.start_index
- for i in range(1, self.max_output_length):
- trg = output_tokens[:, :i]
- output = self.decode(image_features, trg)
- output = torch.argmax(output, dim=-1)
- output_tokens[:, i] = output[-1:]
-
- # Set all tokens after end token to be padding.
- for i in range(1, self.max_output_length):
- indices = output_tokens[:, i - 1] == self.end_index | (
- output_tokens[:, i - 1] == self.pad_index
- )
- output_tokens[indices, i] = self.pad_index
-
- return output_tokens
diff --git a/text_recognizer/networks/residual_network.py b/text_recognizer/networks/residual_network.py
index c33f419..da7553d 100644
--- a/text_recognizer/networks/residual_network.py
+++ b/text_recognizer/networks/residual_network.py
@@ -20,7 +20,11 @@ class Conv2dAuto(nn.Conv2d):
def conv_bn(in_channels: int, out_channels: int, *args, **kwargs) -> nn.Sequential:
"""3x3 convolution with batch norm."""
- conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,)
+ conv3x3 = partial(
+ Conv2dAuto,
+ kernel_size=3,
+ bias=False,
+ )
return nn.Sequential(
conv3x3(in_channels, out_channels, *args, **kwargs),
nn.BatchNorm2d(out_channels),
diff --git a/text_recognizer/networks/transducer/transducer.py b/text_recognizer/networks/transducer/transducer.py
index d7e3d08..b10f93a 100644
--- a/text_recognizer/networks/transducer/transducer.py
+++ b/text_recognizer/networks/transducer/transducer.py
@@ -392,7 +392,12 @@ def load_transducer_loss(
transitions = gtn.load(str(processed_path / transitions))
preprocessor = Preprocessor(
- data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep,
+ data_dir,
+ num_features,
+ tokens_path,
+ lexicon_path,
+ use_words,
+ prepend_wordsep,
)
num_tokens = preprocessor.num_tokens
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
index 8847aba..93a1e43 100644
--- a/text_recognizer/networks/vqvae/decoder.py
+++ b/text_recognizer/networks/vqvae/decoder.py
@@ -44,7 +44,12 @@ class Decoder(nn.Module):
# Configure encoder.
self.decoder = self._build_decoder(
- channels, kernel_sizes, strides, num_residual_layers, activation, dropout,
+ channels,
+ kernel_sizes,
+ strides,
+ num_residual_layers,
+ activation,
+ dropout,
)
def _build_decompression_block(
@@ -72,8 +77,10 @@ class Decoder(nn.Module):
)
)
- if i < len(self.upsampling):
- modules.append(nn.Upsample(size=self.upsampling[i]),)
+ if self.upsampling and i < len(self.upsampling):
+ modules.append(
+ nn.Upsample(size=self.upsampling[i]),
+ )
if dropout is not None:
modules.append(dropout)
@@ -102,7 +109,12 @@ class Decoder(nn.Module):
) -> nn.Sequential:
self.res_block.append(
- nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,)
+ nn.Conv2d(
+ self.embedding_dim,
+ channels[0],
+ kernel_size=1,
+ stride=1,
+ )
)
# Bottleneck module.
diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py
index d3adac5..b0cceed 100644
--- a/text_recognizer/networks/vqvae/encoder.py
+++ b/text_recognizer/networks/vqvae/encoder.py
@@ -1,5 +1,5 @@
"""CNN encoder for the VQ-VAE."""
-from typing import List, Optional, Tuple, Type
+from typing import Sequence, Optional, Tuple, Type
import torch
from torch import nn
@@ -11,7 +11,10 @@ from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer
class _ResidualBlock(nn.Module):
def __init__(
- self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]],
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: Optional[Type[nn.Module]],
) -> None:
super().__init__()
self.block = [
@@ -36,9 +39,9 @@ class Encoder(nn.Module):
def __init__(
self,
in_channels: int,
- channels: List[int],
- kernel_sizes: List[int],
- strides: List[int],
+ channels: Sequence[int],
+ kernel_sizes: Sequence[int],
+ strides: Sequence[int],
num_residual_layers: int,
embedding_dim: int,
num_embeddings: int,
@@ -77,12 +80,12 @@ class Encoder(nn.Module):
self.num_embeddings, self.embedding_dim, self.beta
)
+ @staticmethod
def _build_compression_block(
- self,
in_channels: int,
channels: int,
- kernel_sizes: List[int],
- strides: List[int],
+ kernel_sizes: Sequence[int],
+ strides: Sequence[int],
activation: Type[nn.Module],
dropout: Optional[Type[nn.Module]],
) -> nn.ModuleList:
@@ -109,8 +112,8 @@ class Encoder(nn.Module):
self,
in_channels: int,
channels: int,
- kernel_sizes: List[int],
- strides: List[int],
+ kernel_sizes: Sequence[int],
+ strides: Sequence[int],
num_residual_layers: int,
activation: Type[nn.Module],
dropout: Optional[Type[nn.Module]],
@@ -135,7 +138,12 @@ class Encoder(nn.Module):
)
encoder.append(
- nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,)
+ nn.Conv2d(
+ channels[-1],
+ self.embedding_dim,
+ kernel_size=1,
+ stride=1,
+ )
)
return nn.Sequential(*encoder)
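
Background sketch (a conceptual re-implementation, not the repository's VectorQuantizer):
the encoder ends in a 1x1 convolution that projects the feature maps to embedding_dim,
and the quantizer then snaps each spatial vector to its nearest codebook entry, returning
the commitment/codebook loss that LitVQVAEModel adds to the reconstruction loss:

    from typing import Tuple

    import torch
    from torch import Tensor, nn


    class VectorQuantizerSketch(nn.Module):
        """Nearest-neighbour codebook lookup with a straight-through estimator."""

        def __init__(self, num_embeddings: int, embedding_dim: int, beta: float) -> None:
            super().__init__()
            self.beta = beta
            self.codebook = nn.Embedding(num_embeddings, embedding_dim)

        def forward(self, latents: Tensor) -> Tuple[Tensor, Tensor]:
            b, d, h, w = latents.shape
            flat = latents.permute(0, 2, 3, 1).reshape(-1, d)
            indices = torch.cdist(flat, self.codebook.weight).argmin(dim=1)
            quantized = self.codebook(indices).reshape(b, h, w, d).permute(0, 3, 1, 2)
            # Codebook loss + beta-weighted commitment loss.
            vq_loss = torch.mean((quantized - latents.detach()) ** 2)
            vq_loss += self.beta * torch.mean((quantized.detach() - latents) ** 2)
            # Straight-through estimator: copy gradients past the quantization step.
            quantized = latents + (quantized - latents).detach()
            return quantized, vq_loss
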
diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py
index 50448b4..1f08e5e 100644
--- a/text_recognizer/networks/vqvae/vqvae.py
+++ b/text_recognizer/networks/vqvae/vqvae.py
@@ -1,8 +1,7 @@
"""The VQ-VAE."""
-from typing import List, Optional, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple
-import torch
from torch import nn
from torch import Tensor
@@ -25,6 +24,8 @@ class VQVAE(nn.Module):
beta: float = 0.25,
activation: str = "leaky_relu",
dropout_rate: float = 0.0,
+ *args: Any,
+ **kwargs: Dict,
) -> None:
super().__init__()
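
Putting it together, an end-to-end training sketch under stated assumptions: the trainer
wiring and every config value below are placeholders, and the VQVAE constructor arguments
are guessed from the encoder/decoder signatures rather than taken from this diff; only
the class names come from the changes above:

    import pytorch_lightning as pl
    from omegaconf import OmegaConf

    from text_recognizer.data import IAMExtendedParagraphs
    from text_recognizer.models import LitVQVAEModel
    from text_recognizer.networks import VQVAE

    network = VQVAE(
        in_channels=1,              # hypothetical arguments; see the project configs
        channels=[32, 64, 96],
        kernel_sizes=[4, 4, 4],
        strides=[2, 2, 2],
        num_residual_layers=2,
        embedding_dim=64,
        num_embeddings=512,
    )
    model = LitVQVAEModel(
        network=network,
        optimizer=OmegaConf.create({"type": "Adam", "args": {"lr": 3e-4}}),
        lr_scheduler=OmegaConf.create({"type": "CosineAnnealingLR", "args": {"T_max": 64}}),
        criterion=OmegaConf.create({"type": "MSELoss", "args": {}}),
        monitor="val_loss",
    )
    datamodule = IAMExtendedParagraphs(batch_size=16, num_workers=4)
    trainer = pl.Trainer(max_epochs=64)
    trainer.fit(model, datamodule=datamodule)
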