 notebooks/03-look-at-iam-paragraphs.ipynb     | 90
 text_recognizer/data/mapping.py               |  8
 text_recognizer/models/base.py                |  1
 text_recognizer/models/transformer.py         |  5
 text_recognizer/networks/image_transformer.py | 23
 training/configs/image_transformer.yaml       | 28
 6 files changed, 123 insertions(+), 32 deletions(-)
diff --git a/notebooks/03-look-at-iam-paragraphs.ipynb b/notebooks/03-look-at-iam-paragraphs.ipynb
index 4b82034..cfa0ba5 100644
--- a/notebooks/03-look-at-iam-paragraphs.ipynb
+++ b/notebooks/03-look-at-iam-paragraphs.ipynb
@@ -25,12 +25,13 @@
" sys.path.append('..')\n",
"\n",
"from text_recognizer.data.iam_paragraphs import IAMParagraphs\n",
- "from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs"
+ "from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs\n",
+ "from text_recognizer.data.iam_extended_paragraphs import IAMExtendedParagraphs"
]
},
{
"cell_type": "code",
- "execution_count": 162,
+ "execution_count": 2,
"id": "726ac25b",
"metadata": {},
"outputs": [],
@@ -47,6 +48,65 @@
},
{
"cell_type": "code",
+ "execution_count": 3,
+ "id": "c6188bce",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2021-04-11 21:49:35.313 | INFO | text_recognizer.data.iam_paragraphs:setup:106 - Loading IAM paragraph regions and lines for None...\n",
+ "2021-04-11 21:49:51.802 | INFO | text_recognizer.data.iam_synthetic_paragraphs:setup:77 - IAM Synthetic dataset steup for stage None\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "IAM Original and Synthetic Paragraphs Dataset\n",
+ "Num classes: 84\n",
+ "Dims: (1, 576, 640)\n",
+ "Output dims: (682, 1)\n",
+ "Train/val/test sizes: 19942, 262, 231\n",
+ "Train Batch x stats: (torch.Size([128, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0099), tensor(0.0553), tensor(1.))\n",
+ "Train Batch y stats: (torch.Size([128, 682]), torch.int64, tensor(1), tensor(83))\n",
+ "Test Batch x stats: (torch.Size([128, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0284), tensor(0.0846), tensor(0.9373))\n",
+ "Test Batch y stats: (torch.Size([128, 682]), torch.int64, tensor(1), tensor(83))\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "dataset = IAMExtendedParagraphs()\n",
+ "dataset.prepare_data()\n",
+ "dataset.setup()\n",
+ "print(dataset)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "1b3c7bdd",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "1246.375"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "19942 / 16"
+ ]
+ },
+ {
+ "cell_type": "code",
"execution_count": 4,
"id": "42501428",
"metadata": {},
@@ -152,7 +212,7 @@
},
{
"cell_type": "code",
- "execution_count": 165,
+ "execution_count": 5,
"id": "45649194",
"metadata": {},
"outputs": [],
@@ -163,7 +223,7 @@
},
{
"cell_type": "code",
- "execution_count": 166,
+ "execution_count": 6,
"id": "0fc13f9f",
"metadata": {},
"outputs": [],
@@ -181,6 +241,27 @@
},
{
"cell_type": "code",
+ "execution_count": 8,
+ "id": "fb0afccf",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "1004"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "len(processor.tokens)"
+ ]
+ },
+ {
+ "cell_type": "code",
"execution_count": 167,
"id": "d08a0259",
"metadata": {},
@@ -435,7 +516,6 @@
}
],
"source": [
- "\n",
"# Testing\n",
"\n",
"for _ in range(5):\n",
diff --git a/text_recognizer/data/mapping.py b/text_recognizer/data/mapping.py
new file mode 100644
index 0000000..f0edf3f
--- /dev/null
+++ b/text_recognizer/data/mapping.py
@@ -0,0 +1,8 @@
+"""Mapping to and from word pieces."""
+from pathlib import Path
+
+
+class WordPieces:
+
+ def __init__(self) -> None:
+ pass
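
mapping.py lands as a stub: WordPieces is declared but empty. A hypothetical sketch of the bidirectional token/index interface such a mapping class typically grows into (none of these methods exist in the repository yet; all names below are assumptions):

    from typing import Dict, List


    class WordPieces:
        """Mapping to and from word pieces."""

        def __init__(self, tokens: List[str]) -> None:
            self.tokens = tokens
            # Inverse lookup table: token string -> integer index.
            self.inverse_mapping: Dict[str, int] = {
                token: index for index, token in enumerate(tokens)
            }

        def get_index(self, token: str) -> int:
            return self.inverse_mapping[token]

        def get_token(self, index: int) -> str:
            return self.tokens[index]
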
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 0928e6c..c6d5d73 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -60,6 +60,7 @@ class LitBaseModel(pl.LightningModule):
scheduler["scheduler"] = getattr(
torch.optim.lr_scheduler, self._lr_scheduler.type
)(optimizer, **args)
+
return scheduler
def configure_optimizers(self) -> Tuple[List[type], List[Dict[str, Any]]]:
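
Beyond the whitespace tweak, this is the spot where the scheduler dict is handed back to PyTorch Lightning. A sketch of the dict shape Lightning expects, showing where the `interval: step` key added to the YAML config below plugs in (the exact wiring of the YAML args into this dict is an assumption):

    import torch

    def build_scheduler_dict(optimizer: torch.optim.Optimizer) -> dict:
        return {
            "scheduler": torch.optim.lr_scheduler.OneCycleLR(
                optimizer, max_lr=1e-3, epochs=512, steps_per_epoch=1246
            ),
            "interval": "step",   # step the scheduler per batch, not per epoch
            "monitor": "val_loss",
        }
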
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index b23685b..7dc1352 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,5 +1,5 @@
"""PyTorch Lightning model for base Transformers."""
-from typing import Dict, List, Optional, Union, Tuple
+from typing import Dict, List, Optional, Union, Tuple, Type
from omegaconf import DictConfig, OmegaConf
import pytorch_lightning as pl
@@ -19,7 +19,7 @@ class LitTransformerModel(LitBaseModel):
def __init__(
self,
- network: Type[nn, Module],
+ network: Type[nn.Module],
optimizer: Union[DictConfig, Dict],
lr_scheduler: Union[DictConfig, Dict],
criterion: Union[DictConfig, Dict],
@@ -27,7 +27,6 @@ class LitTransformerModel(LitBaseModel):
mapping: Optional[List[str]] = None,
) -> None:
super().__init__(network, optimizer, lr_scheduler, criterion, monitor)
-
self.mapping, ignore_tokens = self.configure_mapping(mapping)
self.val_cer = CharacterErrorRate(ignore_tokens)
self.test_cer = CharacterErrorRate(ignore_tokens)
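
Two small fixes here: the annotation typo Type[nn, Module] becomes Type[nn.Module] (typing.Type takes a single class, and nn by itself is a module, not a class), and Type is now actually imported. Illustrative use of the corrected annotation:

    from typing import Type

    from torch import nn

    def instantiate(network_cls: Type[nn.Module], **kwargs) -> nn.Module:
        # Type[nn.Module] accepts the class itself (or a subclass), not an instance.
        return network_cls(**kwargs)

    layer = instantiate(nn.Linear, in_features=8, out_features=4)
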
diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py
index 9ed67a4..daededa 100644
--- a/text_recognizer/networks/image_transformer.py
+++ b/text_recognizer/networks/image_transformer.py
@@ -10,16 +10,15 @@ TODO: Local attention for lower layer in attention.
"""
import importlib
import math
-from typing import Dict, List, Union, Sequence, Tuple, Type
+from typing import Dict, Optional, Union, Sequence, Type
from einops import rearrange
from omegaconf import DictConfig, OmegaConf
import torch
from torch import nn
from torch import Tensor
-import torchvision
-from text_recognizer.data.emnist import emnist_mapping
+from text_recognizer.data.emnist import NUM_SPECIAL_TOKENS
from text_recognizer.networks.transformer import (
Decoder,
DecoderLayer,
@@ -28,6 +27,8 @@ from text_recognizer.networks.transformer import (
target_padding_mask,
)
+NUM_WORD_PIECES = 1000
+
class ImageTransformer(nn.Module):
def __init__(
@@ -35,7 +36,7 @@ class ImageTransformer(nn.Module):
input_shape: Sequence[int],
output_shape: Sequence[int],
encoder: Union[DictConfig, Dict],
- mapping: str,
+ vocab_size: Optional[int] = None,
num_decoder_layers: int = 4,
hidden_dim: int = 256,
num_heads: int = 4,
@@ -43,14 +44,9 @@ class ImageTransformer(nn.Module):
dropout_rate: float = 0.1,
transformer_activation: str = "glu",
) -> None:
- # Configure mapping
- mapping, inverse_mapping = self._configure_mapping(mapping)
- self.vocab_size = len(mapping)
+ self.vocab_size = NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size
self.hidden_dim = hidden_dim
self.max_output_length = output_shape[0]
- self.start_index = inverse_mapping["<s>"]
- self.end_index = inverse_mapping["<e>"]
- self.pad_index = inverse_mapping["<p>"]
# Image backbone
self.encoder = self._configure_encoder(encoder)
@@ -107,13 +103,6 @@ class ImageTransformer(nn.Module):
encoder_class = getattr(network_module, encoder.type)
return encoder_class(**encoder.args)
- def _configure_mapping(self, mapping: str) -> Tuple[List[str], Dict[str, int]]:
- """Configures mapping."""
- # TODO: Fix me!!!
- if mapping == "emnist":
- mapping, inverse_mapping, _ = emnist_mapping()
- return mapping, inverse_mapping
-
def encode(self, image: Tensor) -> Tensor:
"""Extracts image features with backbone.
diff --git a/training/configs/image_transformer.yaml b/training/configs/image_transformer.yaml
index bedcbb5..88c05c2 100644
--- a/training/configs/image_transformer.yaml
+++ b/training/configs/image_transformer.yaml
@@ -1,7 +1,7 @@
seed: 4711
network:
- desc: null
+ desc: Configuration of the PyTorch neural network.
type: ImageTransformer
args:
encoder:
@@ -15,20 +15,24 @@ network:
transformer_activation: glu
model:
- desc: null
+ desc: Configuration of the PyTorch Lightning model.
type: LitTransformerModel
args:
optimizer:
type: MADGRAD
args:
- lr: 1.0e-2
+ lr: 1.0e-3
momentum: 0.9
weight_decay: 0
eps: 1.0e-6
lr_scheduler:
- type: CosineAnnealingLR
+ type: OneCycleLR
args:
- T_max: 512
+ interval: &interval step
+ max_lr: 1.0e-3
+ three_phase: true
+ epochs: 512
+ steps_per_epoch: 1246 # num_samples / batch_size
criterion:
type: CrossEntropyLoss
args:
@@ -39,7 +43,7 @@ model:
mapping: sentence_piece
data:
- desc: null
+ desc: Configuration of the training/test data.
type: IAMExtendedParagraphs
args:
batch_size: 16
@@ -52,6 +56,16 @@ callbacks:
args:
monitor: val_loss
mode: min
+ - type: StochasticWeightAveraging
+ args:
+ swa_epoch_start: 0.8
+ swa_lrs: 0.05
+ annealing_epochs: 10
+ annealing_strategy: cos
+ device: null
+ - type: LearningRateMonitor
+ args:
+ logging_interval: *interval
- type: EarlyStopping
args:
monitor: val_loss
@@ -59,7 +73,7 @@ callbacks:
patience: 10
trainer:
- desc: null
+ desc: Configuration of the PyTorch Lightning Trainer.
args:
stochastic_weight_avg: true
auto_scale_batch_size: binsearch
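
The config now drives torch.optim.lr_scheduler.OneCycleLR with per-step updates. A sketch of how the YAML numbers map onto the scheduler (SGD stands in for the third-party MADGRAD optimizer, and it is assumed the `interval` key is consumed by Lightning and stripped before the remaining args reach the scheduler constructor):

    import torch
    from torch import nn

    model = nn.Linear(8, 8)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=1e-3,           # peak learning rate from the config
        three_phase=True,
        epochs=512,
        steps_per_epoch=1246,  # 19942 samples / batch_size 16
    )
    for _ in range(3):
        optimizer.step()
        scheduler.step()       # once per batch, matching interval: step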