From e388cd95c77d37a51324cff9d84a809421bf97d3 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Thu, 8 Apr 2021 23:38:03 +0200
Subject: Bug fixes for word pieces
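
Fix the word-piece pipeline and wire up training:

- base_dataset: use the "<e>" end token instead of "</s>" when
  converting strings to labels.
- iam_preprocessor: accept an optional list of special tokens, add
  them to the token and grapheme vocabularies, and split lines on
  them in to_index so that each special token maps to a single index.
- transforms: let ToCharcters take extra symbols and index the
  mapping instead of calling it.
- models/base: pass the optimizer instance to the lr scheduler, as
  torch.optim.lr_scheduler classes require it as their first argument.
- Add an image transformer experiment config and a run_experiment.py
  training script.

A sketch of the special-token handling (hypothetical arguments):

    preprocessor = Preprocessor(data_dir, num_features,
                                special_tokens=["<s>", "<e>"])
    preprocessor.to_index("<s>foo▁bar<e>")  # one index per special token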

---
 text_recognizer/data/base_dataset.py               |   2 +-
 text_recognizer/data/iam.py                        |   1 -
 text_recognizer/data/iam_preprocessor.py           |  28 ++-
 text_recognizer/data/transforms.py                 |   6 +-
 text_recognizer/models/base.py                     |   6 +-
 .../training/experiments/image_transformer.yaml    |  72 ++++++++
 text_recognizer/training/run_experiment.py         | 201 +++++++++++++++++++++
 7 files changed, 301 insertions(+), 15 deletions(-)
 create mode 100644 text_recognizer/training/experiments/image_transformer.yaml
 create mode 100644 text_recognizer/training/run_experiment.py

diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py
index d00daaf..8d644d4 100644
--- a/text_recognizer/data/base_dataset.py
+++ b/text_recognizer/data/base_dataset.py
@@ -67,7 +67,7 @@ def convert_strings_to_labels(
     labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["<p>"]
     for i, string in enumerate(strings):
         tokens = list(string)
-        tokens = ["<s>", *tokens, "</s>"]
+        tokens = ["<s>", *tokens, "<e>"]
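+        # Both "<s>" and "<e>" must be keys in the mapping.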
         for j, token in enumerate(tokens):
             labels[i, j] = mapping[token]
     return labels
diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py
index 01272ba..261c8d3 100644
--- a/text_recognizer/data/iam.py
+++ b/text_recognizer/data/iam.py
@@ -7,7 +7,6 @@ import zipfile
 
 from boltons.cacheutils import cachedproperty
 from loguru import logger
-from PIL import Image
 import toml
 
 from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py
index 3844419..d85787e 100644
--- a/text_recognizer/data/iam_preprocessor.py
+++ b/text_recognizer/data/iam_preprocessor.py
@@ -47,8 +47,6 @@ def load_metadata(
 class Preprocessor:
     """A preprocessor for the IAM dataset."""
 
-    # TODO: add lower case only to when generating...
-
     def __init__(
         self,
         data_dir: Union[str, Path],
@@ -57,10 +55,12 @@ class Preprocessor:
         lexicon_path: Optional[Union[str, Path]] = None,
         use_words: bool = False,
         prepend_wordsep: bool = False,
+        special_tokens: Optional[List[str]] = None,
     ) -> None:
         self.wordsep = "▁"
         self._use_word = use_words
         self._prepend_wordsep = prepend_wordsep
+        self.special_tokens = special_tokens
 
         self.data_dir = Path(data_dir)
 
@@ -88,6 +88,10 @@ class Preprocessor:
         else:
             self.lexicon = None
 
+        if self.special_tokens is not None:
+            self.tokens += self.special_tokens
+            self.graphemes += self.special_tokens
+
         self.graphemes_to_index = {t: i for i, t in enumerate(self.graphemes)}
         self.tokens_to_index = {t: i for i, t in enumerate(self.tokens)}
         self.num_features = num_features
@@ -115,21 +119,31 @@ class Preprocessor:
                     continue
                 self.text.append(example["text"].lower())
 
-    def to_index(self, line: str) -> torch.LongTensor:
-        """Converts text to a tensor of indices."""
+    def _to_index(self, line: str) -> torch.LongTensor:
+        # Special tokens are atomic and map directly to a single index.
+        if self.special_tokens is not None and line in self.special_tokens:
+            return torch.LongTensor([self.tokens_to_index[line]])
+        # Default to the raw characters; the lexicon branch below may
+        # replace them with word-piece tokens.
+        tokens = line
         token_to_index = self.graphemes_to_index
         if self.lexicon is not None:
             if len(line) > 0:
                 # If the word is not found in the lexicon, fall back to letters.
-                line = [
+                tokens = [
                     t
                     for w in line.split(self.wordsep)
                     for t in self.lexicon.get(w, self.wordsep + w)
                 ]
             token_to_index = self.tokens_to_index
         if self._prepend_wordsep:
-            line = itertools.chain([self.wordsep], line)
-        return torch.LongTensor([token_to_index[t] for t in line])
+            tokens = itertools.chain([self.wordsep], tokens)
+        return torch.LongTensor([token_to_index[t] for t in tokens])
+
+    def to_index(self, line: str) -> torch.LongTensor:
+        """Converts text to a tensor of indices."""
+        if self.special_tokens is not None:
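+            # A capturing group makes re.split keep the special tokens,
+            # e.g. "<s>foo▁bar<e>" -> ["<s>", "foo▁bar", "<e>"].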
+            pattern = f"({'|'.join(self.special_tokens)})"
+            lines = list(filter(None, re.split(pattern, line)))
+            return torch.cat([self._to_index(l) for l in lines])
+        return self._to_index(line)
 
     def to_text(self, indices: List[int]) -> str:
         """Converts indices to text."""
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
index 616e236..297c953 100644
--- a/text_recognizer/data/transforms.py
+++ b/text_recognizer/data/transforms.py
@@ -23,12 +23,12 @@ class ToLower:
 class ToCharcters:
     """Converts integers to characters."""
 
-    def __init__(self) -> None:
-        self.mapping, _, _ = emnist_mapping()
+    def __init__(self, extra_symbols: Optional[List[str]] = None) -> None:
+        self.mapping, _, _ = emnist_mapping(extra_symbols)
 
     def __call__(self, y: Tensor) -> str:
         """Converts a Tensor to a str."""
-        return "".join([self.mapping(int(i)) for i in y]).strip("<p>").replace(" ", "▁")
+        return "".join([self.mapping[int(i)] for i in y]).replace(" ", "▁")
 
 
 class WordPieces:
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 3c1919e..0928e6c 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -49,7 +49,7 @@ class LitBaseModel(pl.LightningModule):
             optimizer_class = getattr(torch.optim, self._optimizer.type)
         return optimizer_class(params=self.parameters(), **args)
 
-    def _configure_lr_scheduler(self) -> Dict[str, Any]:
+    def _configure_lr_scheduler(self, optimizer: torch.optim.Optimizer) -> Dict[str, Any]:
         """Configures the lr scheduler."""
         scheduler = {"monitor": self.monitor}
         args = {} or self._lr_scheduler.args
@@ -59,13 +59,13 @@ class LitBaseModel(pl.LightningModule):
 
         scheduler["scheduler"] = getattr(
             torch.optim.lr_scheduler, self._lr_scheduler.type
-        )(**args)
+        )(optimizer, **args)
         return scheduler
 
     def configure_optimizers(self) -> Tuple[List[type], List[Dict[str, Any]]]:
         """Configures optimizer and lr scheduler."""
         optimizer = self._configure_optimizer()
-        scheduler = self._configure_lr_scheduler()
+        scheduler = self._configure_lr_scheduler(optimizer)
 
         return [optimizer], [scheduler]
 
diff --git a/text_recognizer/training/experiments/image_transformer.yaml b/text_recognizer/training/experiments/image_transformer.yaml
new file mode 100644
index 0000000..bedcbb5
--- /dev/null
+++ b/text_recognizer/training/experiments/image_transformer.yaml
@@ -0,0 +1,72 @@
+seed: 4711
+
+network:
+        desc: null
+        type: ImageTransformer
+        args:
+                encoder:
+                        type: null
+                        args: null
+                num_decoder_layers: 4
+                hidden_dim: 256
+                num_heads: 4
+                expansion_dim: 1024
+                dropout_rate: 0.1
+                transformer_activation: glu
+
+model:
+        desc: null
+        type: LitTransformerModel
+        args:
+                optimizer:
+                        type: MADGRAD
+                        args:
+                                lr: 1.0e-2
+                                momentum: 0.9
+                                weight_decay: 0
+                                eps: 1.0e-6
+                lr_scheduler:
+                        type: CosineAnnealingLR
+                        args:
+                                T_max: 512
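+                                # Matches trainer max_epochs (512) below.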
+                criterion:
+                        type: CrossEntropyLoss
+                        args:
+                                weight: null
+                                ignore_index: -100
+                                reduction: mean
+                monitor: val_loss
+                mapping: sentence_piece
+
+data:
+        desc: null
+        type: IAMExtendedParagraphs
+        args:
+                batch_size: 16
+                num_workers: 12
+                train_fraction: 0.8
+                augment: true
+
+callbacks:
+        - type: ModelCheckpoint
+          args:
+                  monitor: val_loss
+                  mode: min
+        - type: EarlyStopping
+          args:
+                  monitor: val_loss
+                  mode: min
+                  patience: 10
+
+trainer:
+        desc: null
+        args:
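+                # Forwarded verbatim to pl.Trainer in run_experiment.py.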
+                stochastic_weight_avg: true
+                auto_scale_batch_size: binsearch
+                gradient_clip_val: 0
+                fast_dev_run: false
+                gpus: 1
+                precision: 16
+                max_epochs: 512
+                terminate_on_nan: true
+                weights_summary: full
diff --git a/text_recognizer/training/run_experiment.py b/text_recognizer/training/run_experiment.py
new file mode 100644
index 0000000..ed1a947
--- /dev/null
+++ b/text_recognizer/training/run_experiment.py
@@ -0,0 +1,201 @@
+"""Script to run experiments."""
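+# Example invocation (uses the experiment config added in this commit):
+#   python run_experiment.py -f image_transformer.yaml --train --use_wandb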
+import importlib
+from pathlib import Path
+from typing import Dict, List, Optional, Type
+
+import click
+from loguru import logger
+from omegaconf import DictConfig, OmegaConf
+import pytorch_lightning as pl
+from torch import nn
+from tqdm import tqdm
+import wandb
+
+
+SEED = 4711
+EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments"
+
+
+def _configure_logging(log_dir: Optional[Path], verbose: int = 0) -> None:
+    """Configure the loguru logger for output to terminal and disk."""
+
+    def _get_level(verbose: int) -> str:
+        """Sets the logger level."""
+        levels = {0: "WARNING", 1: "INFO", 2: "DEBUG"}
+        verbose = min(verbose, 2)
+        return levels[verbose]
+
+    # Remove default logger to get tqdm to work properly.
+    logger.remove()
+
+    # Fetch verbosity level.
+    level = _get_level(verbose)
+
+    logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level=level)
+    if log_dir is not None:
+        logger.add(
+            str(log_dir / "train.log"),
+            format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}",
+        )
+
+
+def _load_config(file_path: Path) -> DictConfig:
+    """Return experiment config."""
+    logger.info(f"Loading config from: {file_path}")
+    if not file_path.exists():
+        raise FileNotFoundError(f"Experiment config not found at: {file_path}")
+    return OmegaConf.load(file_path)
+
+
+def _import_class(module_and_class_name: str) -> type:
+    """Import class from module."""
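+    # e.g. _import_class("text_recognizer.models.LitTransformerModel")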
+    module_name, class_name = module_and_class_name.rsplit(".", 1)
+    module = importlib.import_module(module_name)
+    return getattr(module, class_name)
+
+
+def _configure_callbacks(
+    callbacks: List[DictConfig],
+) -> List[Type[pl.callbacks.Callback]]:
+    """Configures lightning callbacks."""
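+    # Each entry mirrors the experiment config, e.g.
+    # {type: ModelCheckpoint, args: {monitor: val_loss, mode: min}}.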
+    pl_callbacks = [
+        getattr(pl.callbacks, callback.type)(**callback.args) for callback in callbacks
+    ]
+    return pl_callbacks
+
+
+def _configure_logger(
+    network: Type[nn.Module], args: Dict, use_wandb: bool
+) -> Type[pl.loggers.LightningLoggerBase]:
+    """Configures lightning logger."""
+    if use_wandb:
+        pl_logger = pl.loggers.WandbLogger()
+        pl_logger.watch(network)
+        pl_logger.log_hyperparams(args)
+        return pl_logger
+    return pl.loggers.TensorBoardLogger("training/logs")
+
+
+def _save_best_weights(
+    callbacks: List[Type[pl.callbacks.Callback]], use_wandb: bool
+) -> None:
+    """Saves the best model."""
+    model_checkpoint_callback = next(
+        callback
+        for callback in callbacks
+        if isinstance(callback, pl.callbacks.ModelCheckpoint)
+    )
+    best_model_path = model_checkpoint_callback.best_model_path
+    if best_model_path:
+        logger.info(f"Best model saved at: {best_model_path}")
+        if use_wandb:
+            logger.info("Uploading model to W&B...")
+            wandb.save(best_model_path)
+
+
+def _load_lit_model(
+    lit_model_class: type, network: Type[nn.Module], config: DictConfig
+) -> Type[pl.LightningModule]:
+    """Load lightning model."""
+    if config.load_checkpoint is not None:
+        logger.info(
+            f"Loading network weights from checkpoint: {config.load_checkpoint}"
+        )
+        return lit_model_class.load_from_checkpoint(
+            config.load_checkpoint, network=network, **config.model.args
+        )
+    return lit_model_class(network=network, **config.model.args)
+
+
+def run(
+    filename: str,
+    train: bool,
+    test: bool,
+    tune: bool,
+    use_wandb: bool,
+    verbose: int = 0,
+) -> None:
+    """Runs experiment."""
+
+    _configure_logging(None, verbose=verbose)
+    logger.info("Starting experiment...")
+
+    # Seed everything in the experiment.
+    logger.info(f"Seeding everything with seed={SEED}")
+    pl.utilities.seed.seed_everything(SEED)
+
+    # Load config.
+    file_path = EXPERIMENTS_DIRNAME / filename
+    config = _load_config(file_path)
+
+    # Load classes.
+    data_module_class = _import_class(f"text_recognizer.data.{config.data.type}")
+    network_class = _import_class(f"text_recognizer.networks.{config.network.type}")
+    lit_model_class = _import_class(f"text_recognizer.models.{config.model.type}")
+
+    # Initialize data object and network.
+    data_module = data_module_class(**config.data.args)
+    network = network_class(**data_module.config(), **config.network.args)
+
+    # Load callback and logger.
+    callbacks = _configure_callbacks(config.callbacks)
+    pl_logger = _configure_logger(network, config, use_wandb)
+
+    # Load the lightning model.
+    lit_model = _load_lit_model(lit_model_class, network, config)
+
+    trainer = pl.Trainer(
+        **config.trainer.args,
+        callbacks=callbacks,
+        logger=pl_logger,
+        weights_save_path="training/logs",
+    )
+
+    if tune:
+        logger.info("Tuning learning rate and batch size...")
+        trainer.tune(lit_model, datamodule=data_module)
+
+    if train:
+        logger.info("Training network...")
+        trainer.fit(lit_model, datamodule=data_module)
+
+    if test:
+        logger.info("Testing network...")
+        trainer.test(lit_model, datamodule=data_module)
+
+    _save_best_weights(callbacks, use_wandb)
+
+
+@click.command()
+@click.option("-f", "--experiment_config", type=str, help="Path to experiment config.")
+@click.option("--use_wandb", is_flag=True, help="If true, use W&B for logging.")
+@click.option(
+    "--tune", is_flag=True, help="If true, tune hyperparameters for training."
+)
+@click.option("--train", is_flag=True, help="If true, train the model.")
+@click.option("--test", is_flag=True, help="If true, test the model.")
+@click.option("-v", "--verbose", count=True)
+def cli(
+    experiment_config: str,
+    use_wandb: bool,
+    tune: bool,
+    train: bool,
+    test: bool,
+    verbose: int,
+) -> None:
+    """Run experiment."""
+    run(
+        filename=experiment_config,
+        train=train,
+        test=test,
+        tune=tune,
+        use_wandb=use_wandb,
+        verbose=verbose,
+    )
+
+
+if __name__ == "__main__":
+    cli()
-- 
cgit v1.2.3-70-g09d2