From e388cd95c77d37a51324cff9d84a809421bf97d3 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Thu, 8 Apr 2021 23:38:03 +0200 Subject: Bug fixes word pieces --- README.md | 5 +- notebooks/03-look-at-iam-paragraphs.ipynb | 226 +++++++++++++++++++-- text_recognizer/data/base_dataset.py | 2 +- text_recognizer/data/iam.py | 1 - text_recognizer/data/iam_preprocessor.py | 28 ++- text_recognizer/data/transforms.py | 6 +- text_recognizer/models/base.py | 6 +- .../training/experiments/image_transformer.yaml | 72 +++++++ text_recognizer/training/run_experiment.py | 201 ++++++++++++++++++ training/experiments/image_transformer.yaml | 72 ------- training/run_experiment.py | 201 ------------------ 11 files changed, 515 insertions(+), 305 deletions(-) create mode 100644 text_recognizer/training/experiments/image_transformer.yaml create mode 100644 text_recognizer/training/run_experiment.py delete mode 100644 training/experiments/image_transformer.yaml delete mode 100644 training/run_experiment.py diff --git a/README.md b/README.md index c30ee03..dac7e98 100644 --- a/README.md +++ b/README.md @@ -31,10 +31,11 @@ poetry run build-transitions --tokens iamdb_1kwp_tokens_1000.txt --lexicon iamdb - [x] build_transitions.py - [x] transform that encodes iam targets to wordpieces - [x] transducer loss function -- [ ] Train with word pieces +- [ ] Train with word pieces + - [ ] Pad word pieces index to same length - [ ] Local attention in first layer of transformer - [ ] Halonet encoder -- [ ] Implement CPC +- [ ] Implement CPC - [ ] https://arxiv.org/pdf/1905.09272.pdf - [ ] https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html?highlight=byol diff --git a/notebooks/03-look-at-iam-paragraphs.ipynb b/notebooks/03-look-at-iam-paragraphs.ipynb index df92f99..4b82034 100644 --- a/notebooks/03-look-at-iam-paragraphs.ipynb +++ b/notebooks/03-look-at-iam-paragraphs.ipynb @@ -2,19 +2,10 @@ "cells": [ { "cell_type": "code", - "execution_count": 7, + "execution_count": 1, "id": "6ce2519f", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The autoreload extension is already loaded. 
To reload it, use:\n", - " %reload_ext autoreload\n" - ] - } - ], + "outputs": [], "source": [ "import os\n", "os.environ['CUDA_VISIBLE_DEVICE'] = ''\n", @@ -39,7 +30,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 162, "id": "726ac25b", "metadata": {}, "outputs": [], @@ -56,7 +47,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "id": "42501428", "metadata": {}, "outputs": [ @@ -64,7 +55,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "2021-04-03 21:55:37.196 | INFO | text_recognizer.data.iam_paragraphs:setup:104 - Loading IAM paragraph regions and lines for None...\n" + "2021-04-08 21:48:18.431 | INFO | text_recognizer.data.iam_paragraphs:setup:106 - Loading IAM paragraph regions and lines for None...\n" ] }, { @@ -76,7 +67,7 @@ "Input dims: (1, 576, 640)\n", "Output dims: (682, 1)\n", "Train/val/test sizes: 1046, 262, 231\n", - "Train Batch x stats: (torch.Size([128, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0358), tensor(0.1021), tensor(1.))\n", + "Train Batch x stats: (torch.Size([128, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0371), tensor(0.1049), tensor(1.))\n", "Train Batch y stats: (torch.Size([128, 682]), torch.int64, tensor(1), tensor(83))\n", "Test Batch x stats: (torch.Size([128, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0284), tensor(0.0846), tensor(0.9373))\n", "Test Batch y stats: (torch.Size([128, 682]), torch.int64, tensor(1), tensor(83))\n", @@ -91,6 +82,211 @@ "print(dataset)" ] }, + { + "cell_type": "code", + "execution_count": 163, + "id": "0cf22683", + "metadata": {}, + "outputs": [], + "source": [ + "x, y = dataset.data_train[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 164, + "id": "98dd0ee6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 1, 33, 47, 44, 66, 51, 40, 59, 59, 44, 57, 66, 43, 54, 66, 53, 54, 59,\n", + " 66, 57, 44, 46, 40, 57, 43, 66, 59, 47, 44, 52, 58, 44, 51, 61, 44, 58,\n", + " 66, 40, 58, 66, 44, 63, 55, 44, 57, 59, 83, 40, 43, 61, 48, 58, 44, 57,\n", + " 58, 76, 66, 41, 60, 59, 66, 40, 57, 44, 66, 55, 57, 44, 55, 40, 57, 44,\n", + " 43, 66, 59, 54, 66, 58, 44, 44, 50, 66, 54, 60, 59, 66, 59, 47, 44, 83,\n", + " 40, 55, 55, 57, 54, 55, 57, 48, 40, 59, 44, 66, 58, 54, 60, 57, 42, 44,\n", + " 58, 66, 54, 45, 66, 48, 53, 45, 54, 57, 52, 40, 59, 48, 54, 53, 66, 54,\n", + " 57, 66, 40, 43, 61, 48, 42, 44, 78, 83, 33, 54, 62, 40, 57, 43, 58, 66,\n", + " 59, 47, 44, 66, 44, 53, 43, 66, 54, 45, 66, 5, 13, 9, 10, 76, 66, 26,\n", + " 57, 78, 66, 17, 40, 53, 48, 44, 51, 66, 20, 57, 40, 53, 59, 76, 83, 40,\n", + " 53, 66, 18, 52, 55, 51, 54, 64, 44, 44, 66, 31, 44, 51, 40, 59, 48, 54,\n", + " 53, 58, 66, 28, 45, 45, 48, 42, 44, 57, 66, 54, 45, 83, 31, 54, 51, 51,\n", + " 58, 77, 31, 54, 64, 42, 44, 66, 25, 59, 43, 78, 66, 40, 53, 43, 66, 40,\n", + " 66, 52, 44, 52, 41, 44, 57, 66, 54, 45, 66, 59, 47, 44, 83, 36, 54, 57,\n", + " 50, 44, 57, 58, 71, 66, 18, 43, 60, 42, 40, 59, 48, 54, 53, 40, 51, 66,\n", + " 14, 58, 58, 54, 42, 48, 40, 59, 48, 54, 53, 76, 66, 58, 60, 41, 52, 48,\n", + " 59, 59, 44, 43, 83, 59, 54, 66, 59, 47, 44, 66, 25, 54, 57, 43, 66, 29,\n", + " 57, 54, 61, 54, 58, 59, 66, 54, 45, 66, 20, 51, 40, 58, 46, 54, 62, 76,\n", + " 66, 17, 57, 78, 66, 14, 53, 43, 57, 44, 62, 83, 21, 54, 54, 43, 76, 66,\n", + " 40, 66, 42, 54, 55, 64, 66, 54, 45, 66, 47, 48, 58, 66, 57, 44, 55, 54,\n", + " 57, 59, 66, 54, 53, 66, 40, 53, 66, 44, 53, 56, 60, 48, 57, 64, 83, 47,\n", + " 44, 66, 47, 40, 43, 66, 52, 40, 43, 44, 
66, 48, 53, 59, 54, 66, 59, 47,\n", + " 44, 66, 55, 57, 54, 41, 51, 44, 52, 58, 66, 59, 47, 40, 59, 66, 41, 44,\n", + " 58, 44, 59, 83, 54, 51, 43, 44, 57, 66, 62, 54, 57, 50, 44, 57, 58, 66,\n", + " 40, 53, 43, 66, 59, 47, 44, 66, 44, 45, 45, 44, 42, 59, 58, 66, 54, 45,\n", + " 66, 57, 44, 59, 48, 57, 44, 52, 44, 53, 59, 83, 21, 60, 46, 44, 53, 59,\n", + " 54, 41, 51, 44, 57, 66, 31, 54, 46, 44, 57, 2, 3, 3, 3, 3, 3, 3,\n", + " 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n", + " 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n", + " 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n", + " 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n", + " 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n", + " 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n", + " 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n", + " 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n", + " 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n", + " 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n", + " 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])" + ] + }, + "execution_count": 164, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y" + ] + }, + { + "cell_type": "code", + "execution_count": 165, + "id": "45649194", + "metadata": {}, + "outputs": [], + "source": [ + "from text_recognizer.data.iam_preprocessor import Preprocessor\n", + "from pathlib import Path" + ] + }, + { + "cell_type": "code", + "execution_count": 166, + "id": "0fc13f9f", + "metadata": {}, + "outputs": [], + "source": [ + "processor = Preprocessor(\n", + " data_dir=Path(\"../data/downloaded/iam/iamdb\"),\n", + " num_features=1000,\n", + " lexicon_path=Path(\"../data/processed/iam_lines/iamdb_1kwp_lex_1000.txt\"),\n", + " tokens_path=Path(\"../data/processed/iam_lines/iamdb_1kwp_tokens_1000.txt\"),\n", + " use_words=True,\n", + " prepend_wordsep=False,\n", + " special_tokens=[\"\", \"\", \"

\", \"\\n\"]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 167, + "id": "d08a0259", + "metadata": {}, + "outputs": [], + "source": [ + "t = convert_y_label_to_string(y, dataset.mapping)" + ] + }, + { + "cell_type": "code", + "execution_count": 168, + "id": "a16a2cb7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\"The latter do not regard themselves as expert\\nadvisers, but are prepared to seek out the\\nappropriate sources of information or advice.\\nTowards the end of 1956, Mr. Daniel Grant,\\nan Employee Relations Officer of\\nRolls-Royce Ltd. and a member of the\\nWorkers' Educational Association, submitted\\nto the Lord Provost of Glasgow, Dr. Andrew\\nHood, a copy of his report on an enquiry\\nhe had made into the problems that beset\\nolder workers and the effects of retirement\\nHugentobler Roger\"" + ] + }, + "execution_count": 168, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "t" + ] + }, + { + "cell_type": "code", + "execution_count": 170, + "id": "c7a33b2d", + "metadata": {}, + "outputs": [], + "source": [ + "ii = processor.to_index(t.replace(\" \", \"▁\").lower())" + ] + }, + { + "cell_type": "code", + "execution_count": 171, + "id": "4e0a22f4", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([247])" + ] + }, + "execution_count": 171, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ii.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 172, + "id": "bc1c5ffb", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([247])" + ] + }, + "execution_count": 172, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ii.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 176, + "id": "8b7b0373", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\"▁the▁latter▁do▁not▁regard▁themselves▁as▁expert\\n▁advisers,▁but▁are▁prepared▁to▁seek▁out▁the\\n▁appropriate▁sources▁of▁information▁or▁advice.\\n▁towards▁the▁end▁of▁1956,▁mr.▁daniel▁grant,\\n▁an▁employee▁relations▁officer▁of\\n▁rolls-royce▁ltd.▁and▁a▁member▁of▁the\\n▁workers'▁educational▁association,▁submitted\\n▁to▁the▁lord▁provost▁of▁glasgow,▁dr.▁andrew\\n▁hood,▁a▁copy▁of▁his▁report▁on▁an▁enquiry\\n▁he▁had▁made▁into▁the▁problems▁that▁beset\\n▁older▁workers▁and▁the▁effects▁of▁retirement\\n▁hugentobler▁roger\"" + ] + }, + "execution_count": 176, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "processor.to_text(ii)" + ] + }, { "cell_type": "code", "execution_count": 4, diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py index d00daaf..8d644d4 100644 --- a/text_recognizer/data/base_dataset.py +++ b/text_recognizer/data/base_dataset.py @@ -67,7 +67,7 @@ def convert_strings_to_labels( labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["

"] for i, string in enumerate(strings): tokens = list(string) - tokens = ["", *tokens, ""] + tokens = ["", *tokens, ""] for j, token in enumerate(tokens): labels[i, j] = mapping[token] return labels diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py index 01272ba..261c8d3 100644 --- a/text_recognizer/data/iam.py +++ b/text_recognizer/data/iam.py @@ -7,7 +7,6 @@ import zipfile from boltons.cacheutils import cachedproperty from loguru import logger -from PIL import Image import toml from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py index 3844419..d85787e 100644 --- a/text_recognizer/data/iam_preprocessor.py +++ b/text_recognizer/data/iam_preprocessor.py @@ -47,8 +47,6 @@ def load_metadata( class Preprocessor: """A preprocessor for the IAM dataset.""" - # TODO: add lower case only to when generating... - def __init__( self, data_dir: Union[str, Path], @@ -57,10 +55,12 @@ class Preprocessor: lexicon_path: Optional[Union[str, Path]] = None, use_words: bool = False, prepend_wordsep: bool = False, + special_tokens: Optional[List[str]] = None, ) -> None: self.wordsep = "▁" self._use_word = use_words self._prepend_wordsep = prepend_wordsep + self.special_tokens = special_tokens if special_tokens is not None else None self.data_dir = Path(data_dir) @@ -88,6 +88,10 @@ class Preprocessor: else: self.lexicon = None + if self.special_tokens is not None: + self.tokens += self.special_tokens + self.graphemes += self.special_tokens + self.graphemes_to_index = {t: i for i, t in enumerate(self.graphemes)} self.tokens_to_index = {t: i for i, t in enumerate(self.tokens)} self.num_features = num_features @@ -115,21 +119,31 @@ class Preprocessor: continue self.text.append(example["text"].lower()) - def to_index(self, line: str) -> torch.LongTensor: - """Converts text to a tensor of indices.""" + + def _to_index(self, line: str) -> torch.LongTensor: + if line in self.special_tokens: + return torch.LongTensor([self.tokens_to_index[line]]) token_to_index = self.graphemes_to_index if self.lexicon is not None: if len(line) > 0: # If the word is not found in the lexicon, fall back to letters. 
- line = [ + tokens = [ t for w in line.split(self.wordsep) for t in self.lexicon.get(w, self.wordsep + w) ] token_to_index = self.tokens_to_index if self._prepend_wordsep: - line = itertools.chain([self.wordsep], line) - return torch.LongTensor([token_to_index[t] for t in line]) + tokens = itertools.chain([self.wordsep], tokens) + return torch.LongTensor([token_to_index[t] for t in tokens]) + + def to_index(self, line: str) -> torch.LongTensor: + """Converts text to a tensor of indices.""" + if self.special_tokens is not None: + pattern = f"({'|'.join(self.special_tokens)})" + lines = list(filter(None, re.split(pattern, line))) + return torch.cat([self._to_index(l) for l in lines]) + return self._to_index(line) def to_text(self, indices: List[int]) -> str: """Converts indices to text.""" diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py index 616e236..297c953 100644 --- a/text_recognizer/data/transforms.py +++ b/text_recognizer/data/transforms.py @@ -23,12 +23,12 @@ class ToLower: class ToCharcters: """Converts integers to characters.""" - def __init__(self) -> None: - self.mapping, _, _ = emnist_mapping() + def __init__(self, extra_symbols: Optional[List[str]] = None) -> None: + self.mapping, _, _ = emnist_mapping(extra_symbols) def __call__(self, y: Tensor) -> str: """Converts a Tensor to a str.""" - return "".join([self.mapping(int(i)) for i in y]).strip("

").replace(" ", "▁") + return "".join([self.mapping[int(i)] for i in y]).replace(" ", "▁") class WordPieces: diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 3c1919e..0928e6c 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -49,7 +49,7 @@ class LitBaseModel(pl.LightningModule): optimizer_class = getattr(torch.optim, self._optimizer.type) return optimizer_class(params=self.parameters(), **args) - def _configure_lr_scheduler(self) -> Dict[str, Any]: + def _configure_lr_scheduler(self, optimizer: Type[torch.optim.Optimizer]) -> Dict[str, Any]: """Configures the lr scheduler.""" scheduler = {"monitor": self.monitor} args = {} or self._lr_scheduler.args @@ -59,13 +59,13 @@ class LitBaseModel(pl.LightningModule): scheduler["scheduler"] = getattr( torch.optim.lr_scheduler, self._lr_scheduler.type - )(**args) + )(optimizer, **args) return scheduler def configure_optimizers(self) -> Tuple[List[type], List[Dict[str, Any]]]: """Configures optimizer and lr scheduler.""" optimizer = self._configure_optimizer() - scheduler = self._configure_lr_scheduler() + scheduler = self._configure_lr_scheduler(optimizer) return [optimizer], [scheduler] diff --git a/text_recognizer/training/experiments/image_transformer.yaml b/text_recognizer/training/experiments/image_transformer.yaml new file mode 100644 index 0000000..bedcbb5 --- /dev/null +++ b/text_recognizer/training/experiments/image_transformer.yaml @@ -0,0 +1,72 @@ +seed: 4711 + +network: + desc: null + type: ImageTransformer + args: + encoder: + type: null + args: null + num_decoder_layers: 4 + hidden_dim: 256 + num_heads: 4 + expansion_dim: 1024 + dropout_rate: 0.1 + transformer_activation: glu + +model: + desc: null + type: LitTransformerModel + args: + optimizer: + type: MADGRAD + args: + lr: 1.0e-2 + momentum: 0.9 + weight_decay: 0 + eps: 1.0e-6 + lr_scheduler: + type: CosineAnnealingLR + args: + T_max: 512 + criterion: + type: CrossEntropyLoss + args: + weight: None + ignore_index: -100 + reduction: mean + monitor: val_loss + mapping: sentence_piece + +data: + desc: null + type: IAMExtendedParagraphs + args: + batch_size: 16 + num_workers: 12 + train_fraction: 0.8 + augment: true + +callbacks: + - type: ModelCheckpoint + args: + monitor: val_loss + mode: min + - type: EarlyStopping + args: + monitor: val_loss + mode: min + patience: 10 + +trainer: + desc: null + args: + stochastic_weight_avg: true + auto_scale_batch_size: binsearch + gradient_clip_val: 0 + fast_dev_run: false + gpus: 1 + precision: 16 + max_epochs: 512 + terminate_on_nan: true + weights_summary: true diff --git a/text_recognizer/training/run_experiment.py b/text_recognizer/training/run_experiment.py new file mode 100644 index 0000000..ed1a947 --- /dev/null +++ b/text_recognizer/training/run_experiment.py @@ -0,0 +1,201 @@ +"""Script to run experiments.""" +from datetime import datetime +import importlib +from pathlib import Path +from typing import Dict, List, Optional, Type + +import click +from loguru import logger +from omegaconf import DictConfig, OmegaConf +import pytorch_lightning as pl +import torch +from torch import nn +from torchsummary import summary +from tqdm import tqdm +import wandb + + +SEED = 4711 +EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments" + + +def _configure_logging(log_dir: Optional[Path], verbose: int = 0) -> None: + """Configure the loguru logger for output to terminal and disk.""" + + def _get_level(verbose: int) -> str: + """Sets the logger level.""" + levels = 
{0: "WARNING", 1: "INFO", 2: "DEBUG"} + verbose = min(verbose, 2) + return levels[verbose] + + # Remove default logger to get tqdm to work properly. + logger.remove() + + # Fetch verbosity level. + level = _get_level(verbose) + + logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level=level) + if log_dir is not None: + logger.add( + str(log_dir / "train.log"), + format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}", + ) + + +def _load_config(file_path: Path) -> DictConfig: + """Return experiment config.""" + logger.info(f"Loading config from: {file_path}") + if not file_path.exists(): + raise FileNotFoundError(f"Experiment config not found at: {file_path}") + return OmegaConf.load(file_path) + + +def _import_class(module_and_class_name: str) -> type: + """Import class from module.""" + module_name, class_name = module_and_class_name.rsplit(".", 1) + module = importlib.import_module(module_name) + return getattr(module, class_name) + + +def _configure_callbacks( + callbacks: List[DictConfig], +) -> List[Type[pl.callbacks.Callback]]: + """Configures lightning callbacks.""" + pl_callbacks = [ + getattr(pl.callbacks, callback.type)(**callback.args) for callback in callbacks + ] + return pl_callbacks + + +def _configure_logger( + network: Type[nn.Module], args: Dict, use_wandb: bool +) -> Type[pl.loggers.LightningLoggerBase]: + """Configures lightning logger.""" + if use_wandb: + pl_logger = pl.loggers.WandbLogger() + pl_logger.watch(network) + pl_logger.log_hyperparams(vars(args)) + return pl_logger + return pl.logger.TensorBoardLogger("training/logs") + + +def _save_best_weights( + callbacks: List[Type[pl.callbacks.Callback]], use_wandb: bool +) -> None: + """Saves the best model.""" + model_checkpoint_callback = next( + callback + for callback in callbacks + if isinstance(callback, pl.callbacks.ModelCheckpoint) + ) + best_model_path = model_checkpoint_callback.best_model_path + if best_model_path: + logger.info(f"Best model saved at: {best_model_path}") + if use_wandb: + logger.info("Uploading model to W&B...") + wandb.save(best_model_path) + + +def _load_lit_model( + lit_model_class: type, network: Type[nn.Module], config: DictConfig +) -> Type[pl.LightningModule]: + """Load lightning model.""" + if config.load_checkpoint is not None: + logger.info( + f"Loading network weights from checkpoint: {config.load_checkpoint}" + ) + return lit_model_class.load_from_checkpoint( + config.load_checkpoint, network=network, **config.model.args + ) + return lit_model_class(network=network, **config.model.args) + + +def run( + filename: str, + train: bool, + test: bool, + tune: bool, + use_wandb: bool, + verbose: int = 0, +) -> None: + """Runs experiment.""" + + _configure_logging(None, verbose=verbose) + logger.info("Starting experiment...") + + # Seed everything in the experiment. + logger.info(f"Seeding everthing with seed={SEED}") + pl.utilities.seed.seed_everything(SEED) + + # Load config. + file_path = EXPERIMENTS_DIRNAME / filename + config = _load_config(file_path) + + # Load classes. + data_module_class = _import_class(f"text_recognizer.data.{config.data.type}") + network_class = _import_class(f"text_recognizer.networks.{config.network.type}") + lit_model_class = _import_class(f"text_recognizer.models.{config.model.type}") + + # Initialize data object and network. + data_module = data_module_class(**config.data.args) + network = network_class(**data_module.config(), **config.network.args) + + # Load callback and logger. 
+ callbacks = _configure_callbacks(config.callbacks) + pl_logger = _configure_logger(network, config, use_wandb) + + # Load ligtning model. + lit_model = _load_lit_model(lit_model_class, network, config) + + trainer = pl.Trainer( + **config.trainer.args, + callbacks=callbacks, + logger=pl_logger, + weigths_save_path="training/logs", + ) + + if tune: + logger.info(f"Tuning learning rate and batch size...") + trainer.tune(lit_model, datamodule=data_module) + + if train: + logger.info(f"Training network...") + trainer.fit(lit_model, datamodule=data_module) + + if test: + logger.info(f"Testing network...") + trainer.test(lit_model, datamodule=data_module) + + _save_best_weights(callbacks, use_wandb) + + +@click.command() +@click.option("-f", "--experiment_config", type=str, help="Path to experiment config.") +@click.option("--use_wandb", is_flag=True, help="If true, do use wandb for logging.") +@click.option( + "--tune", is_flag=True, help="If true, tune hyperparameters for training." +) +@click.option("--train", is_flag=True, help="If true, train the model.") +@click.option("--test", is_flag=True, help="If true, test the model.") +@click.option("-v", "--verbose", count=True) +def cli( + experiment_config: str, + use_wandb: bool, + tune: bool, + train: bool, + test: bool, + verbose: int, +) -> None: + """Run experiment.""" + run( + filename=experiment_config, + train=train, + test=test, + tune=tune, + use_wandb=use_wandb, + verbose=verbose, + ) + + +if __name__ == "__main__": + cli() diff --git a/training/experiments/image_transformer.yaml b/training/experiments/image_transformer.yaml deleted file mode 100644 index bedcbb5..0000000 --- a/training/experiments/image_transformer.yaml +++ /dev/null @@ -1,72 +0,0 @@ -seed: 4711 - -network: - desc: null - type: ImageTransformer - args: - encoder: - type: null - args: null - num_decoder_layers: 4 - hidden_dim: 256 - num_heads: 4 - expansion_dim: 1024 - dropout_rate: 0.1 - transformer_activation: glu - -model: - desc: null - type: LitTransformerModel - args: - optimizer: - type: MADGRAD - args: - lr: 1.0e-2 - momentum: 0.9 - weight_decay: 0 - eps: 1.0e-6 - lr_scheduler: - type: CosineAnnealingLR - args: - T_max: 512 - criterion: - type: CrossEntropyLoss - args: - weight: None - ignore_index: -100 - reduction: mean - monitor: val_loss - mapping: sentence_piece - -data: - desc: null - type: IAMExtendedParagraphs - args: - batch_size: 16 - num_workers: 12 - train_fraction: 0.8 - augment: true - -callbacks: - - type: ModelCheckpoint - args: - monitor: val_loss - mode: min - - type: EarlyStopping - args: - monitor: val_loss - mode: min - patience: 10 - -trainer: - desc: null - args: - stochastic_weight_avg: true - auto_scale_batch_size: binsearch - gradient_clip_val: 0 - fast_dev_run: false - gpus: 1 - precision: 16 - max_epochs: 512 - terminate_on_nan: true - weights_summary: true diff --git a/training/run_experiment.py b/training/run_experiment.py deleted file mode 100644 index 491112c..0000000 --- a/training/run_experiment.py +++ /dev/null @@ -1,201 +0,0 @@ -"""Script to run experiments.""" -from datetime import datetime -import importlib -from pathlib import Path -from typing import Dict, List, NamedTuple, Optional, Union, Type - -import click -from loguru import logger -from omegaconf import DictConfig, OmegaConf -import pytorch_lightning as pl -import torch -from torch import nn -from torchsummary import summary -from tqdm import tqdm -import wandb - - -SEED = 4711 -EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments" - - -def 
_configure_logging(log_dir: Optional[Path], verbose: int = 0) -> None: - """Configure the loguru logger for output to terminal and disk.""" - - def _get_level(verbose: int) -> str: - """Sets the logger level.""" - levels = {0: "WARNING", 1: "INFO", 2: "DEBUG"} - verbose = min(verbose, 2) - return levels[verbose] - - # Have to remove default logger to get tqdm to work properly. - logger.remove() - - # Fetch verbosity level. - level = _get_level(verbose) - - logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level=level) - if log_dir is not None: - logger.add( - str(log_dir / "train.log"), - format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}", - ) - - -def _load_config(file_path: Path) -> DictConfig: - """Return experiment config.""" - logger.info(f"Loading config from: {file_path}") - if not file_path.exists(): - raise FileNotFoundError(f"Experiment config not found at: {file_path}") - return OmegaConf.load(file_path) - - -def _import_class(module_and_class_name: str) -> type: - """Import class from module.""" - module_name, class_name = module_and_class_name.rsplit(".", 1) - module = importlib.import_module(module_name) - return getattr(module, class_name) - - -def _configure_callbacks( - callbacks: List[DictConfig], -) -> List[Type[pl.callbacks.Callback]]: - """Configures lightning callbacks.""" - pl_callbacks = [ - getattr(pl.callbacks, callback.type)(**callback.args) for callback in callbacks - ] - return pl_callbacks - - -def _configure_logger( - network: Type[nn.Module], args: Dict, use_wandb: bool -) -> Type[pl.loggers.LightningLoggerBase]: - """Configures lightning logger.""" - if use_wandb: - pl_logger = pl.loggers.WandbLogger() - pl_logger.watch(network) - pl_logger.log_hyperparams(vars(args)) - return pl_logger - return pl.logger.TensorBoardLogger("training/logs") - - -def _save_best_weights( - callbacks: List[Type[pl.callbacks.Callback]], use_wandb: bool -) -> None: - """Saves the best model.""" - model_checkpoint_callback = next( - callback - for callback in callbacks - if isinstance(callback, pl.callbacks.ModelCheckpoint) - ) - best_model_path = model_checkpoint_callback.best_model_path - if best_model_path: - logger.info(f"Best model saved at: {best_model_path}") - if use_wandb: - logger.info("Uploading model to W&B...") - wandb.save(best_model_path) - - -def _load_lit_model( - lit_model_class: type, network: Type[nn.Module], config: DictConfig -) -> Type[pl.LightningModule]: - """Load lightning model.""" - if config.load_checkpoint is not None: - logger.info( - f"Loading network weights from checkpoint: {config.load_checkpoint}" - ) - return lit_model_class.load_from_checkpoint( - config.load_checkpoint, network=network, **config.model.args - ) - return lit_model_class(network=network, **config.model.args) - - -def run( - filename: str, - train: bool, - test: bool, - tune: bool, - use_wandb: bool, - verbose: int = 0, -) -> None: - """Runs experiment.""" - - _configure_logging(None, verbose=verbose) - logger.info("Starting experiment...") - - # Seed everything in the experiment. - logger.info(f"Seeding everthing with seed={SEED}") - pl.utilities.seed.seed_everything(SEED) - - # Load config. - file_path = EXPERIMENTS_DIRNAME / filename - config = _load_config(file_path) - - # Load classes. 
- data_module_class = _import_class(f"text_recognizer.data.{config.data.type}") - network_class = _import_class(f"text_recognizer.networks.{config.network.type}") - lit_model_class = _import_class(f"text_recognizer.models.{config.model.type}") - - # Initialize data object and network. - data_module = data_module_class(**config.data.args) - network = network_class(**data_module.config(), **config.network.args) - - # Load callback and logger. - callbacks = _configure_callbacks(config.callbacks) - pl_logger = _configure_logger(network, config, use_wandb) - - # Load ligtning model. - lit_model = _load_lit_model(lit_model_class, network, config) - - trainer = pl.Trainer( - **config.trainer.args, - callbacks=callbacks, - logger=pl_logger, - weigths_save_path="training/logs", - ) - - if tune: - logger.info(f"Tuning learning rate and batch size...") - trainer.tune(lit_model, datamodule=data_module) - - if train: - logger.info(f"Training network...") - trainer.fit(lit_model, datamodule=data_module) - - if test: - logger.info(f"Testing network...") - trainer.test(lit_model, datamodule=data_module) - - _save_best_weights(callbacks, use_wandb) - - -@click.command() -@click.option("-f", "--experiment_config", type=str, help="Path to experiment config.") -@click.option("--use_wandb", is_flag=True, help="If true, do use wandb for logging.") -@click.option( - "--tune", is_flag=True, help="If true, tune hyperparameters for training." -) -@click.option("--train", is_flag=True, help="If true, train the model.") -@click.option("--test", is_flag=True, help="If true, test the model.") -@click.option("-v", "--verbose", count=True) -def cli( - experiment_config: str, - use_wandb: bool, - tune: bool, - train: bool, - test: bool, - verbose: int, -) -> None: - """Run experiment.""" - run( - filename=experiment_config, - train=train, - test=test, - tune=tune, - use_wandb=use_wandb, - verbose=verbose, - ) - - -if __name__ == "__main__": - cli() -- cgit v1.2.3-70-g09d2
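The heart of the word-piece fix above is that `Preprocessor.to_index` now splits its input on any registered special tokens, encodes each special token as a single index, and only runs the normal word-piece lookup on the text in between. A stripped-down sketch of that splitting step, using a made-up token table in place of the generated 1k word-piece vocabulary (all names and indices here are illustrative assumptions, not the real mapping):

```python
import re
from typing import Dict, List

import torch

# Made-up stand-in for the word-piece vocabulary loaded from the generated token files.
tokens_to_index: Dict[str, int] = {"▁the": 0, "▁end": 1, "\n": 2}
special_tokens: List[str] = ["\n"]


def encode_plain(chunk: str) -> torch.LongTensor:
    """Stand-in for the lexicon-based word-piece lookup on an ordinary text chunk."""
    return torch.LongTensor([tokens_to_index[chunk]])


def to_index(line: str) -> torch.LongTensor:
    """Split on special tokens first, then encode every chunk and concatenate."""
    pattern = f"({'|'.join(map(re.escape, special_tokens))})"
    chunks = [c for c in re.split(pattern, line) if c]
    parts = [
        torch.LongTensor([tokens_to_index[c]]) if c in special_tokens else encode_plain(c)
        for c in chunks
    ]
    return torch.cat(parts)


print(to_index("▁the\n▁end"))  # tensor([0, 2, 1])
```

In the patched `Preprocessor`, the non-special chunks go through `_to_index`, which expands each word into its word pieces via the lexicon before the index lookup.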
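The remaining README item, padding word-piece indices to a common length, would follow the same pattern `convert_strings_to_labels` already uses for character targets: allocate a tensor filled with the padding index and copy each sequence into its row. A minimal sketch of that step (the pad index, batch, and target length below are illustrative assumptions):

```python
from typing import List

import torch


def pad_word_pieces(
    sequences: List[torch.LongTensor], length: int, pad_index: int
) -> torch.LongTensor:
    """Pad variable-length index sequences to one fixed length, row by row."""
    labels = torch.ones((len(sequences), length), dtype=torch.long) * pad_index
    for i, seq in enumerate(sequences):
        labels[i, : len(seq)] = seq
    return labels


# Illustrative batch, target length, and pad index.
batch = [torch.LongTensor([0, 2, 1]), torch.LongTensor([1, 1])]
print(pad_word_pieces(batch, length=5, pad_index=3))
# tensor([[0, 2, 1, 3, 3],
#         [1, 1, 3, 3, 3]])
```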
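The `models/base.py` change reflects that every class in `torch.optim.lr_scheduler` takes the optimizer instance as its first constructor argument, so the scheduler can only be built once the optimizer exists. A minimal Lightning-free sketch of the same wiring, reusing the `CosineAnnealingLR` settings from the experiment config; the toy network and learning rate are illustrative only, and a stock AdamW stands in for the MADGRAD optimizer named in the config:

```python
import torch
from torch import nn

# Toy network purely for illustration; the experiment builds an ImageTransformer.
network = nn.Linear(16, 4)

# The optimizer is configured first...
optimizer = torch.optim.AdamW(network.parameters(), lr=1e-2)

# ...and the scheduler is then constructed from that optimizer instance,
# which is why _configure_lr_scheduler now receives the optimizer.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=512)

# Lightning-style return value of configure_optimizers: one optimizer and
# one scheduler dict that monitors the validation loss.
optimizers = [optimizer]
lr_schedulers = [{"scheduler": scheduler, "monitor": "val_loss"}]
```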