| author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-08 23:38:03 +0200 |
|---|---|---|
| committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-08 23:38:03 +0200 |
| commit | e388cd95c77d37a51324cff9d84a809421bf97d3 (patch) | |
| tree | d585545f85d03ea8a6907daba254821fddeb1589 | |
| parent | f4629a0d4149d5870c9fd8ce83ff5d391bd7ddd3 (diff) | |
Bug fixes for word pieces
| Mode | File | Lines |
|---|---|---|
| -rw-r--r-- | README.md | 5 |
| -rw-r--r-- | notebooks/03-look-at-iam-paragraphs.ipynb | 226 |
| -rw-r--r-- | text_recognizer/data/base_dataset.py | 2 |
| -rw-r--r-- | text_recognizer/data/iam.py | 1 |
| -rw-r--r-- | text_recognizer/data/iam_preprocessor.py | 28 |
| -rw-r--r-- | text_recognizer/data/transforms.py | 6 |
| -rw-r--r-- | text_recognizer/models/base.py | 6 |
| -rw-r--r-- | text_recognizer/training/experiments/image_transformer.yaml (renamed from training/experiments/image_transformer.yaml) | 0 |
| -rw-r--r-- | text_recognizer/training/run_experiment.py (renamed from training/run_experiment.py) | 4 |

9 files changed, 244 insertions, 34 deletions
diff --git a/README.md b/README.md
--- a/README.md
+++ b/README.md
@@ -31,10 +31,11 @@ poetry run build-transitions --tokens iamdb_1kwp_tokens_1000.txt --lexicon iamdb
 - [x] build_transitions.py
 - [x] transform that encodes iam targets to wordpieces
 - [x] transducer loss function
-- [ ] Train with word pieces
+- [ ] Train with word pieces
+    - [ ] Pad word pieces index to same length
 - [ ] Local attention in first layer of transformer
 - [ ] Halonet encoder
-- [ ] Implement CPC
+- [ ] Implement CPC
 - [ ] https://arxiv.org/pdf/1905.09272.pdf
 - [ ] https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html?highlight=byol
diff --git a/notebooks/03-look-at-iam-paragraphs.ipynb b/notebooks/03-look-at-iam-paragraphs.ipynb
index df92f99..4b82034 100644
--- a/notebooks/03-look-at-iam-paragraphs.ipynb
+++ b/notebooks/03-look-at-iam-paragraphs.ipynb
@@ -2,19 +2,10 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 1,
    "id": "6ce2519f",
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "The autoreload extension is already loaded. To reload it, use:\n",
-      "  %reload_ext autoreload\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "import os\n",
     "os.environ['CUDA_VISIBLE_DEVICE'] = ''\n",
@@ -39,7 +30,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 162,
    "id": "726ac25b",
    "metadata": {},
    "outputs": [],
@@ -56,7 +47,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 4,
    "id": "42501428",
    "metadata": {},
    "outputs": [
     {
@@ -64,7 +55,7 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "2021-04-03 21:55:37.196 | INFO | text_recognizer.data.iam_paragraphs:setup:104 - Loading IAM paragraph regions and lines for None...\n"
+      "2021-04-08 21:48:18.431 | INFO | text_recognizer.data.iam_paragraphs:setup:106 - Loading IAM paragraph regions and lines for None...\n"
     ]
    },
    {
@@ -76,7 +67,7 @@
      "Input dims: (1, 576, 640)\n",
      "Output dims: (682, 1)\n",
      "Train/val/test sizes: 1046, 262, 231\n",
-      "Train Batch x stats: (torch.Size([128, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0358), tensor(0.1021), tensor(1.))\n",
+      "Train Batch x stats: (torch.Size([128, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0371), tensor(0.1049), tensor(1.))\n",
      "Train Batch y stats: (torch.Size([128, 682]), torch.int64, tensor(1), tensor(83))\n",
      "Test Batch x stats: (torch.Size([128, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0284), tensor(0.0846), tensor(0.9373))\n",
      "Test Batch y stats: (torch.Size([128, 682]), torch.int64, tensor(1), tensor(83))\n",
@@ -93,6 +84,211 @@
   },
   {
    "cell_type": "code",
+   "execution_count": 163,
+   "id": "0cf22683",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "x, y = dataset.data_train[1]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 164,
+   "id": "98dd0ee6",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([ 1, 33, 47, 44, 66, 51, 40, 59, 59, 44, 57, 66, 43, 54, 66, 53, 54, 59,\n",
+       "        66, 57, 44, 46, 40, 57, 43, 66, 59, 47, 44, 52, 58, 44, 51, 61, 44, 58,\n",
+       "        66, 40, 58, 66, 44, 63, 55, 44, 57, 59, 83, 40, 43, 61, 48, 58, 44, 57,\n",
+       "        58, 76, 66, 41, 60, 59, 66, 40, 57, 44, 66, 55, 57, 44, 55, 40, 57, 44,\n",
+       "        43, 66, 59, 54, 66, 58, 44, 44, 50, 66, 54, 60, 59, 66, 59, 47, 44, 83,\n",
+       "        40, 55, 55, 57, 54, 55, 57, 48, 40, 59, 44, 66, 58, 54, 60, 57, 42, 44,\n",
+       "        58, 66, 54, 45, 66, 48, 53, 45, 54, 57, 52, 40, 59, 48, 54, 53, 66, 54,\n",
+       "        57, 66, 40, 43, 61, 48, 42, 44, 78, 83, 33, 54, 62, 40, 57, 43, 58, 66,\n",
+       "        59, 47, 44, 66, 44, 53, 43, 66, 54, 45, 66, 5, 13, 9, 10, 76, 66, 26,\n",
+       "        57, 78, 66, 17, 40, 53, 48, 44, 51, 66, 20, 57, 40, 53, 59, 76, 83, 40,\n",
+       "        53, 66, 18, 52, 55, 51, 54, 64, 44, 44, 66, 31, 44, 51, 40, 59, 48, 54,\n",
+       "        53, 58, 66, 28, 45, 45, 48, 42, 44, 57, 66, 54, 45, 83, 31, 54, 51, 51,\n",
+       "        58, 77, 31, 54, 64, 42, 44, 66, 25, 59, 43, 78, 66, 40, 53, 43, 66, 40,\n",
+       "        66, 52, 44, 52, 41, 44, 57, 66, 54, 45, 66, 59, 47, 44, 83, 36, 54, 57,\n",
+       "        50, 44, 57, 58, 71, 66, 18, 43, 60, 42, 40, 59, 48, 54, 53, 40, 51, 66,\n",
+       "        14, 58, 58, 54, 42, 48, 40, 59, 48, 54, 53, 76, 66, 58, 60, 41, 52, 48,\n",
+       "        59, 59, 44, 43, 83, 59, 54, 66, 59, 47, 44, 66, 25, 54, 57, 43, 66, 29,\n",
+       "        57, 54, 61, 54, 58, 59, 66, 54, 45, 66, 20, 51, 40, 58, 46, 54, 62, 76,\n",
+       "        66, 17, 57, 78, 66, 14, 53, 43, 57, 44, 62, 83, 21, 54, 54, 43, 76, 66,\n",
+       "        40, 66, 42, 54, 55, 64, 66, 54, 45, 66, 47, 48, 58, 66, 57, 44, 55, 54,\n",
+       "        57, 59, 66, 54, 53, 66, 40, 53, 66, 44, 53, 56, 60, 48, 57, 64, 83, 47,\n",
+       "        44, 66, 47, 40, 43, 66, 52, 40, 43, 44, 66, 48, 53, 59, 54, 66, 59, 47,\n",
+       "        44, 66, 55, 57, 54, 41, 51, 44, 52, 58, 66, 59, 47, 40, 59, 66, 41, 44,\n",
+       "        58, 44, 59, 83, 54, 51, 43, 44, 57, 66, 62, 54, 57, 50, 44, 57, 58, 66,\n",
+       "        40, 53, 43, 66, 59, 47, 44, 66, 44, 45, 45, 44, 42, 59, 58, 66, 54, 45,\n",
+       "        66, 57, 44, 59, 48, 57, 44, 52, 44, 53, 59, 83, 21, 60, 46, 44, 53, 59,\n",
+       "        54, 41, 51, 44, 57, 66, 31, 54, 46, 44, 57, 2, 3, 3, 3, 3, 3, 3,\n",
+       "        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n",
+       "        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n",
+       "        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n",
+       "        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n",
+       "        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n",
+       "        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n",
+       "        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n",
+       "        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n",
+       "        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n",
+       "        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n",
+       "        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])"
+      ]
+     },
+     "execution_count": 164,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "y"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 165,
+   "id": "45649194",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from text_recognizer.data.iam_preprocessor import Preprocessor\n",
+    "from pathlib import Path"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 166,
+   "id": "0fc13f9f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "processor = Preprocessor(\n",
+    "    data_dir=Path(\"../data/downloaded/iam/iamdb\"),\n",
+    "    num_features=1000,\n",
+    "    lexicon_path=Path(\"../data/processed/iam_lines/iamdb_1kwp_lex_1000.txt\"),\n",
+    "    tokens_path=Path(\"../data/processed/iam_lines/iamdb_1kwp_tokens_1000.txt\"),\n",
+    "    use_words=True,\n",
+    "    prepend_wordsep=False,\n",
+    "    special_tokens=[\"<s>\", \"<e>\", \"<p>\", \"\\n\"]\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 167,
+   "id": "d08a0259",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "t = convert_y_label_to_string(y, dataset.mapping)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 168,
+   "id": "a16a2cb7",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "\"<s>The latter do not regard themselves as expert\\nadvisers, but are prepared to seek out the\\nappropriate sources of information or advice.\\nTowards the end of 1956, Mr. Daniel Grant,\\nan Employee Relations Officer of\\nRolls-Royce Ltd. and a member of the\\nWorkers' Educational Association, submitted\\nto the Lord Provost of Glasgow, Dr. Andrew\\nHood, a copy of his report on an enquiry\\nhe had made into the problems that beset\\nolder workers and the effects of retirement\\nHugentobler Roger<e>\""
+      ]
+     },
+     "execution_count": 168,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "t"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 170,
+   "id": "c7a33b2d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ii = processor.to_index(t.replace(\" \", \"▁\").lower())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 171,
+   "id": "4e0a22f4",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([247])"
+      ]
+     },
+     "execution_count": 171,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "ii.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 172,
+   "id": "bc1c5ffb",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([247])"
+      ]
+     },
+     "execution_count": 172,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "ii.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 176,
+   "id": "8b7b0373",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "\"<s>▁the▁latter▁do▁not▁regard▁themselves▁as▁expert\\n▁advisers,▁but▁are▁prepared▁to▁seek▁out▁the\\n▁appropriate▁sources▁of▁information▁or▁advice.\\n▁towards▁the▁end▁of▁1956,▁mr.▁daniel▁grant,\\n▁an▁employee▁relations▁officer▁of\\n▁rolls-royce▁ltd.▁and▁a▁member▁of▁the\\n▁workers'▁educational▁association,▁submitted\\n▁to▁the▁lord▁provost▁of▁glasgow,▁dr.▁andrew\\n▁hood,▁a▁copy▁of▁his▁report▁on▁an▁enquiry\\n▁he▁had▁made▁into▁the▁problems▁that▁beset\\n▁older▁workers▁and▁the▁effects▁of▁retirement\\n▁hugentobler▁roger<e>\""
+      ]
+     },
+     "execution_count": 176,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "processor.to_text(ii)"
+   ]
+  },
+  {
+   "cell_type": "code",
    "execution_count": 4,
    "id": "e7778ae2",
    "metadata": {
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py
index d00daaf..8d644d4 100644
--- a/text_recognizer/data/base_dataset.py
+++ b/text_recognizer/data/base_dataset.py
@@ -67,7 +67,7 @@ def convert_strings_to_labels(
     labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["<p>"]
     for i, string in enumerate(strings):
         tokens = list(string)
-        tokens = ["<s>", *tokens, "</s>"]
+        tokens = ["<s>", *tokens, "<e>"]
         for j, token in enumerate(tokens):
             labels[i, j] = mapping[token]
     return labels
diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py
index 01272ba..261c8d3 100644
--- a/text_recognizer/data/iam.py
+++ b/text_recognizer/data/iam.py
@@ -7,7 +7,6 @@ import zipfile
 
 from boltons.cacheutils import cachedproperty
 from loguru import logger
-from PIL import Image
 import toml
 
 from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py
index 3844419..d85787e 100644
--- a/text_recognizer/data/iam_preprocessor.py
+++ b/text_recognizer/data/iam_preprocessor.py
@@ -47,8 +47,6 @@ def load_metadata(
 class Preprocessor:
     """A preprocessor for the IAM dataset."""
 
-    # TODO: add lower case only to when generating...
-
     def __init__(
         self,
         data_dir: Union[str, Path],
@@ -57,10 +55,12 @@ class Preprocessor:
         lexicon_path: Optional[Union[str, Path]] = None,
         use_words: bool = False,
         prepend_wordsep: bool = False,
+        special_tokens: Optional[List[str]] = None,
     ) -> None:
         self.wordsep = "▁"
         self._use_word = use_words
         self._prepend_wordsep = prepend_wordsep
+        self.special_tokens = special_tokens if special_tokens is not None else None
 
         self.data_dir = Path(data_dir)
 
@@ -88,6 +88,10 @@ class Preprocessor:
         else:
             self.lexicon = None
 
+        if self.special_tokens is not None:
+            self.tokens += self.special_tokens
+            self.graphemes += self.special_tokens
+
         self.graphemes_to_index = {t: i for i, t in enumerate(self.graphemes)}
         self.tokens_to_index = {t: i for i, t in enumerate(self.tokens)}
         self.num_features = num_features
@@ -115,21 +119,31 @@
                 continue
             self.text.append(example["text"].lower())
 
-    def to_index(self, line: str) -> torch.LongTensor:
-        """Converts text to a tensor of indices."""
+
+    def _to_index(self, line: str) -> torch.LongTensor:
+        if line in self.special_tokens:
+            return torch.LongTensor([self.tokens_to_index[line]])
         token_to_index = self.graphemes_to_index
         if self.lexicon is not None:
             if len(line) > 0:
                 # If the word is not found in the lexicon, fall back to letters.
-                line = [
+                tokens = [
                     t
                     for w in line.split(self.wordsep)
                     for t in self.lexicon.get(w, self.wordsep + w)
                 ]
             token_to_index = self.tokens_to_index
         if self._prepend_wordsep:
-            line = itertools.chain([self.wordsep], line)
-        return torch.LongTensor([token_to_index[t] for t in line])
+            tokens = itertools.chain([self.wordsep], tokens)
+        return torch.LongTensor([token_to_index[t] for t in tokens])
+
+    def to_index(self, line: str) -> torch.LongTensor:
+        """Converts text to a tensor of indices."""
+        if self.special_tokens is not None:
+            pattern = f"({'|'.join(self.special_tokens)})"
+            lines = list(filter(None, re.split(pattern, line)))
+            return torch.cat([self._to_index(l) for l in lines])
+        return self._to_index(line)
 
     def to_text(self, indices: List[int]) -> str:
         """Converts indices to text."""
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
index 616e236..297c953 100644
--- a/text_recognizer/data/transforms.py
+++ b/text_recognizer/data/transforms.py
@@ -23,12 +23,12 @@ class ToLower:
 class ToCharcters:
     """Converts integers to characters."""
 
-    def __init__(self) -> None:
-        self.mapping, _, _ = emnist_mapping()
+    def __init__(self, extra_symbols: Optional[List[str]] = None) -> None:
+        self.mapping, _, _ = emnist_mapping(extra_symbols)
 
     def __call__(self, y: Tensor) -> str:
         """Converts a Tensor to a str."""
-        return "".join([self.mapping(int(i)) for i in y]).strip("<p>").replace(" ", "▁")
+        return "".join([self.mapping[int(i)] for i in y]).replace(" ", "▁")
 
 
 class WordPieces:
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 3c1919e..0928e6c 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -49,7 +49,7 @@ class LitBaseModel(pl.LightningModule):
         optimizer_class = getattr(torch.optim, self._optimizer.type)
         return optimizer_class(params=self.parameters(), **args)
 
-    def _configure_lr_scheduler(self) -> Dict[str, Any]:
+    def _configure_lr_scheduler(self, optimizer: Type[torch.optim.Optimizer]) -> Dict[str, Any]:
         """Configures the lr scheduler."""
         scheduler = {"monitor": self.monitor}
         args = {} or self._lr_scheduler.args
@@ -59,13 +59,13 @@ class LitBaseModel(pl.LightningModule):
         scheduler["scheduler"] = getattr(
             torch.optim.lr_scheduler, self._lr_scheduler.type
-        )(**args)
+        )(optimizer, **args)
         return scheduler
 
     def configure_optimizers(self) -> Tuple[List[type], List[Dict[str, Any]]]:
         """Configures optimizer and lr scheduler."""
         optimizer = self._configure_optimizer()
-        scheduler = self._configure_lr_scheduler()
+        scheduler = self._configure_lr_scheduler(optimizer)
         return [optimizer], [scheduler]
diff --git a/training/experiments/image_transformer.yaml b/text_recognizer/training/experiments/image_transformer.yaml
index bedcbb5..bedcbb5 100644
--- a/training/experiments/image_transformer.yaml
+++ b/text_recognizer/training/experiments/image_transformer.yaml
diff --git a/training/run_experiment.py b/text_recognizer/training/run_experiment.py
index 491112c..ed1a947 100644
--- a/training/run_experiment.py
+++ b/text_recognizer/training/run_experiment.py
@@ -2,7 +2,7 @@
 from datetime import datetime
 import importlib
 from pathlib import Path
-from typing import Dict, List, NamedTuple, Optional, Union, Type
+from typing import Dict, List, Optional, Type
 
 import click
 from loguru import logger
@@ -28,7 +28,7 @@ def _configure_logging(log_dir: Optional[Path], verbose: int = 0) -> None:
         verbose = min(verbose, 2)
         return levels[verbose]
 
-    # Have to remove default logger to get tqdm to work properly.
+    # Remove default logger to get tqdm to work properly.
     logger.remove()
 
     # Fetch verbosity level.
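The core of the word-piece fixes is the new `to_index` in `iam_preprocessor.py`, which splits a line on the configured special tokens before the lexicon lookup, so `<s>`, `<e>`, `<p>`, and newline each map to a single index instead of being pushed through the word-piece lexicon. Below is a condensed, self-contained sketch of that split-and-encode logic with a toy vocabulary standing in for the real 1k-word-piece lexicon files; the `re.escape` call is added here for safety and is not in the patch itself.

```python
import re

import torch

# Toy vocabulary standing in for the 1k-word-piece lexicon files.
SPECIAL_TOKENS = ["<s>", "<e>", "<p>", "\n"]
TOKENS = ["▁the", "▁end", "▁of"] + SPECIAL_TOKENS
TOKENS_TO_INDEX = {t: i for i, t in enumerate(TOKENS)}


def to_index(line: str) -> torch.LongTensor:
    """Encode a line, mapping each special token to a single index."""
    # Split on the special tokens but keep them (capturing group); re.split
    # leaves empty strings at the boundaries, which filter(None, ...) drops.
    pattern = f"({'|'.join(re.escape(t) for t in SPECIAL_TOKENS)})"
    indices = []
    for part in filter(None, re.split(pattern, line)):
        if part in SPECIAL_TOKENS:
            indices.append(TOKENS_TO_INDEX[part])
        else:
            # Stand-in for the lexicon lookup: assume the part is already
            # a whitespace-separated sequence of known word pieces.
            indices.extend(TOKENS_TO_INDEX[wp] for wp in part.split())
    return torch.LongTensor(indices)


print(to_index("<s>▁the ▁end ▁of<e>"))  # tensor([3, 0, 1, 2, 4])
```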
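The `base_dataset.py` change swaps the end-of-sequence token from `</s>` to `<e>`, matching the `<e>` the notebook decodes at the end of `t`. Here is a minimal, self-contained version of `convert_strings_to_labels` with a toy character mapping; the real mapping comes from `emnist_mapping`, and the special-token indices (`<s>` = 1, `<e>` = 2, `<p>` = 3) only mirror what the printed target tensor suggests, where 1 opens the sequence, 2 closes it, and trailing 3s are padding.

```python
import torch

# Toy mapping; the real one is built by emnist_mapping().
mapping = {"<s>": 1, "<e>": 2, "<p>": 3, "a": 4, "b": 5}


def convert_strings_to_labels(strings, mapping, length):
    """Wrap each string in <s>...<e> and pad with <p> to a fixed length."""
    labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["<p>"]
    for i, string in enumerate(strings):
        tokens = ["<s>", *list(string), "<e>"]
        for j, token in enumerate(tokens):
            labels[i, j] = mapping[token]
    return labels


print(convert_strings_to_labels(["ab", "a"], mapping, length=6))
# tensor([[1, 4, 5, 2, 3, 3],
#         [1, 4, 2, 3, 3, 3]])
```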
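The notebook also motivates the new README item "Pad word pieces index to same length": the character-level target `y` is padded out to 682 entries, while the word-piece encoding `ii` has only 247 and no padding. One way to batch such variable-length index tensors is `torch.nn.utils.rnn.pad_sequence`; the padding index 3 below is an assumption carried over from the printed target tensor, not something the patch pins down.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

PAD_INDEX = 3  # assumed <p> index, mirroring the notebook's target tensor

# Two word-piece index sequences of different lengths.
sequences = [torch.LongTensor([1, 7, 9, 2]), torch.LongTensor([1, 5, 2])]

# Right-pad every sequence to the length of the longest one.
batch = pad_sequence(sequences, batch_first=True, padding_value=PAD_INDEX)
print(batch)
# tensor([[1, 7, 9, 2],
#         [1, 5, 2, 3]])
```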
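Finally, the `models/base.py` change fixes a genuine bug rather than a style issue: every scheduler class in `torch.optim.lr_scheduler` takes the optimizer it steps as its first constructor argument, so instantiating one via `getattr` with `**args` alone cannot work. A standalone sketch of the fixed lookup follows; the `StepLR` type and its arguments are hypothetical stand-ins for whatever `self._lr_scheduler.type` and `self._lr_scheduler.args` hold in the experiment config.

```python
import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Stand-ins for self._lr_scheduler.type and self._lr_scheduler.args.
scheduler_type = "StepLR"
scheduler_args = {"step_size": 10, "gamma": 0.1}

# Look the class up by name and pass the optimizer first, as the patch does.
scheduler = getattr(torch.optim.lr_scheduler, scheduler_type)(optimizer, **scheduler_args)
print(type(scheduler).__name__)  # StepLR
```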