author    Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-04-08 23:38:03 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-04-08 23:38:03 +0200
commit    e388cd95c77d37a51324cff9d84a809421bf97d3 (patch)
tree      d585545f85d03ea8a6907daba254821fddeb1589
parent    f4629a0d4149d5870c9fd8ce83ff5d391bd7ddd3 (diff)
Bug fixes for word pieces
-rw-r--r--  README.md                                                        5
-rw-r--r--  notebooks/03-look-at-iam-paragraphs.ipynb                      226
-rw-r--r--  text_recognizer/data/base_dataset.py                             2
-rw-r--r--  text_recognizer/data/iam.py                                      1
-rw-r--r--  text_recognizer/data/iam_preprocessor.py                        28
-rw-r--r--  text_recognizer/data/transforms.py                               6
-rw-r--r--  text_recognizer/models/base.py                                   6
-rw-r--r--  text_recognizer/training/experiments/image_transformer.yaml (renamed from training/experiments/image_transformer.yaml)  0
-rw-r--r--  text_recognizer/training/run_experiment.py (renamed from training/run_experiment.py)  4
9 files changed, 244 insertions(+), 34 deletions(-)
diff --git a/README.md b/README.md
index c30ee03..dac7e98 100644
--- a/README.md
+++ b/README.md
@@ -31,10 +31,11 @@ poetry run build-transitions --tokens iamdb_1kwp_tokens_1000.txt --lexicon iamdb
- [x] build_transitions.py
- [x] transform that encodes iam targets to wordpieces
- [x] transducer loss function
-- [ ] Train with word pieces
+- [ ] Train with word pieces
+ - [ ] Pad word pieces index to same length
- [ ] Local attention in first layer of transformer
- [ ] Halonet encoder
-- [ ] Implement CPC
+- [ ] Implement CPC
- [ ] https://arxiv.org/pdf/1905.09272.pdf
- [ ] https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html?highlight=byol
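The new sub-task "Pad word pieces index to same length" amounts to right-padding the variable-length index tensors (247 entries in the notebook example further down) with the `<p>` token id so they can be stacked into batches. A minimal sketch of that idea — the `pad_word_pieces` helper name and signature are illustrative, not part of this commit:

```python
from typing import List

import torch


def pad_word_pieces(indices: List[torch.LongTensor], length: int, pad_index: int) -> torch.Tensor:
    """Right-pads each 1D index tensor with pad_index up to a fixed length."""
    batch = torch.full((len(indices), length), pad_index, dtype=torch.long)
    for i, tensor in enumerate(indices):
        batch[i, : len(tensor)] = tensor
    return batch
```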
diff --git a/notebooks/03-look-at-iam-paragraphs.ipynb b/notebooks/03-look-at-iam-paragraphs.ipynb
index df92f99..4b82034 100644
--- a/notebooks/03-look-at-iam-paragraphs.ipynb
+++ b/notebooks/03-look-at-iam-paragraphs.ipynb
@@ -2,19 +2,10 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 1,
"id": "6ce2519f",
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "The autoreload extension is already loaded. To reload it, use:\n",
- " %reload_ext autoreload\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"import os\n",
"os.environ['CUDA_VISIBLE_DEVICE'] = ''\n",
@@ -39,7 +30,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 162,
"id": "726ac25b",
"metadata": {},
"outputs": [],
@@ -56,7 +47,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 4,
"id": "42501428",
"metadata": {},
"outputs": [
@@ -64,7 +55,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "2021-04-03 21:55:37.196 | INFO | text_recognizer.data.iam_paragraphs:setup:104 - Loading IAM paragraph regions and lines for None...\n"
+ "2021-04-08 21:48:18.431 | INFO | text_recognizer.data.iam_paragraphs:setup:106 - Loading IAM paragraph regions and lines for None...\n"
]
},
{
@@ -76,7 +67,7 @@
"Input dims: (1, 576, 640)\n",
"Output dims: (682, 1)\n",
"Train/val/test sizes: 1046, 262, 231\n",
- "Train Batch x stats: (torch.Size([128, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0358), tensor(0.1021), tensor(1.))\n",
+ "Train Batch x stats: (torch.Size([128, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0371), tensor(0.1049), tensor(1.))\n",
"Train Batch y stats: (torch.Size([128, 682]), torch.int64, tensor(1), tensor(83))\n",
"Test Batch x stats: (torch.Size([128, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0284), tensor(0.0846), tensor(0.9373))\n",
"Test Batch y stats: (torch.Size([128, 682]), torch.int64, tensor(1), tensor(83))\n",
@@ -93,6 +84,211 @@
},
{
"cell_type": "code",
+ "execution_count": 163,
+ "id": "0cf22683",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x, y = dataset.data_train[1]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 164,
+ "id": "98dd0ee6",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([ 1, 33, 47, 44, 66, 51, 40, 59, 59, 44, 57, 66, 43, 54, 66, 53, 54, 59,\n",
+ " 66, 57, 44, 46, 40, 57, 43, 66, 59, 47, 44, 52, 58, 44, 51, 61, 44, 58,\n",
+ " 66, 40, 58, 66, 44, 63, 55, 44, 57, 59, 83, 40, 43, 61, 48, 58, 44, 57,\n",
+ " 58, 76, 66, 41, 60, 59, 66, 40, 57, 44, 66, 55, 57, 44, 55, 40, 57, 44,\n",
+ " 43, 66, 59, 54, 66, 58, 44, 44, 50, 66, 54, 60, 59, 66, 59, 47, 44, 83,\n",
+ " 40, 55, 55, 57, 54, 55, 57, 48, 40, 59, 44, 66, 58, 54, 60, 57, 42, 44,\n",
+ " 58, 66, 54, 45, 66, 48, 53, 45, 54, 57, 52, 40, 59, 48, 54, 53, 66, 54,\n",
+ " 57, 66, 40, 43, 61, 48, 42, 44, 78, 83, 33, 54, 62, 40, 57, 43, 58, 66,\n",
+ " 59, 47, 44, 66, 44, 53, 43, 66, 54, 45, 66, 5, 13, 9, 10, 76, 66, 26,\n",
+ " 57, 78, 66, 17, 40, 53, 48, 44, 51, 66, 20, 57, 40, 53, 59, 76, 83, 40,\n",
+ " 53, 66, 18, 52, 55, 51, 54, 64, 44, 44, 66, 31, 44, 51, 40, 59, 48, 54,\n",
+ " 53, 58, 66, 28, 45, 45, 48, 42, 44, 57, 66, 54, 45, 83, 31, 54, 51, 51,\n",
+ " 58, 77, 31, 54, 64, 42, 44, 66, 25, 59, 43, 78, 66, 40, 53, 43, 66, 40,\n",
+ " 66, 52, 44, 52, 41, 44, 57, 66, 54, 45, 66, 59, 47, 44, 83, 36, 54, 57,\n",
+ " 50, 44, 57, 58, 71, 66, 18, 43, 60, 42, 40, 59, 48, 54, 53, 40, 51, 66,\n",
+ " 14, 58, 58, 54, 42, 48, 40, 59, 48, 54, 53, 76, 66, 58, 60, 41, 52, 48,\n",
+ " 59, 59, 44, 43, 83, 59, 54, 66, 59, 47, 44, 66, 25, 54, 57, 43, 66, 29,\n",
+ " 57, 54, 61, 54, 58, 59, 66, 54, 45, 66, 20, 51, 40, 58, 46, 54, 62, 76,\n",
+ " 66, 17, 57, 78, 66, 14, 53, 43, 57, 44, 62, 83, 21, 54, 54, 43, 76, 66,\n",
+ " 40, 66, 42, 54, 55, 64, 66, 54, 45, 66, 47, 48, 58, 66, 57, 44, 55, 54,\n",
+ " 57, 59, 66, 54, 53, 66, 40, 53, 66, 44, 53, 56, 60, 48, 57, 64, 83, 47,\n",
+ " 44, 66, 47, 40, 43, 66, 52, 40, 43, 44, 66, 48, 53, 59, 54, 66, 59, 47,\n",
+ " 44, 66, 55, 57, 54, 41, 51, 44, 52, 58, 66, 59, 47, 40, 59, 66, 41, 44,\n",
+ " 58, 44, 59, 83, 54, 51, 43, 44, 57, 66, 62, 54, 57, 50, 44, 57, 58, 66,\n",
+ " 40, 53, 43, 66, 59, 47, 44, 66, 44, 45, 45, 44, 42, 59, 58, 66, 54, 45,\n",
+ " 66, 57, 44, 59, 48, 57, 44, 52, 44, 53, 59, 83, 21, 60, 46, 44, 53, 59,\n",
+ " 54, 41, 51, 44, 57, 66, 31, 54, 46, 44, 57, 2, 3, 3, 3, 3, 3, 3,\n",
+ " 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n",
+ " 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n",
+ " 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n",
+ " 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n",
+ " 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n",
+ " 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n",
+ " 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n",
+ " 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n",
+ " 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n",
+ " 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n",
+ " 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])"
+ ]
+ },
+ "execution_count": 164,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "y"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 165,
+ "id": "45649194",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.data.iam_preprocessor import Preprocessor\n",
+ "from pathlib import Path"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 166,
+ "id": "0fc13f9f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "processor = Preprocessor(\n",
+ " data_dir=Path(\"../data/downloaded/iam/iamdb\"),\n",
+ " num_features=1000,\n",
+ " lexicon_path=Path(\"../data/processed/iam_lines/iamdb_1kwp_lex_1000.txt\"),\n",
+ " tokens_path=Path(\"../data/processed/iam_lines/iamdb_1kwp_tokens_1000.txt\"),\n",
+ " use_words=True,\n",
+ " prepend_wordsep=False,\n",
+ " special_tokens=[\"<s>\", \"<e>\", \"<p>\", \"\\n\"]\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 167,
+ "id": "d08a0259",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t = convert_y_label_to_string(y, dataset.mapping)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 168,
+ "id": "a16a2cb7",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "\"<s>The latter do not regard themselves as expert\\nadvisers, but are prepared to seek out the\\nappropriate sources of information or advice.\\nTowards the end of 1956, Mr. Daniel Grant,\\nan Employee Relations Officer of\\nRolls-Royce Ltd. and a member of the\\nWorkers' Educational Association, submitted\\nto the Lord Provost of Glasgow, Dr. Andrew\\nHood, a copy of his report on an enquiry\\nhe had made into the problems that beset\\nolder workers and the effects of retirement\\nHugentobler Roger<e>\""
+ ]
+ },
+ "execution_count": 168,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "t"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 170,
+ "id": "c7a33b2d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ii = processor.to_index(t.replace(\" \", \"▁\").lower())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 171,
+ "id": "4e0a22f4",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([247])"
+ ]
+ },
+ "execution_count": 171,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "ii.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 172,
+ "id": "bc1c5ffb",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([247])"
+ ]
+ },
+ "execution_count": 172,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "ii.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 176,
+ "id": "8b7b0373",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "\"<s>▁the▁latter▁do▁not▁regard▁themselves▁as▁expert\\n▁advisers,▁but▁are▁prepared▁to▁seek▁out▁the\\n▁appropriate▁sources▁of▁information▁or▁advice.\\n▁towards▁the▁end▁of▁1956,▁mr.▁daniel▁grant,\\n▁an▁employee▁relations▁officer▁of\\n▁rolls-royce▁ltd.▁and▁a▁member▁of▁the\\n▁workers'▁educational▁association,▁submitted\\n▁to▁the▁lord▁provost▁of▁glasgow,▁dr.▁andrew\\n▁hood,▁a▁copy▁of▁his▁report▁on▁an▁enquiry\\n▁he▁had▁made▁into▁the▁problems▁that▁beset\\n▁older▁workers▁and▁the▁effects▁of▁retirement\\n▁hugentobler▁roger<e>\""
+ ]
+ },
+ "execution_count": 176,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "processor.to_text(ii)"
+ ]
+ },
+ {
+ "cell_type": "code",
"execution_count": 4,
"id": "e7778ae2",
"metadata": {
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py
index d00daaf..8d644d4 100644
--- a/text_recognizer/data/base_dataset.py
+++ b/text_recognizer/data/base_dataset.py
@@ -67,7 +67,7 @@ def convert_strings_to_labels(
labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["<p>"]
for i, string in enumerate(strings):
tokens = list(string)
- tokens = ["<s>", *tokens, "</s>"]
+ tokens = ["<s>", *tokens, "<e>"]
for j, token in enumerate(tokens):
labels[i, j] = mapping[token]
return labels
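The `</s>` → `<e>` change fixes a KeyError: the mapping defines the end-of-sequence token as `<e>`, so looking up `</s>` fails. A toy run of the fixed function with a hypothetical five-symbol mapping:

```python
import torch

mapping = {"<p>": 0, "<s>": 1, "<e>": 2, "a": 3, "b": 4}


def convert_strings_to_labels(strings, mapping, length):
    """Wraps each string in <s>/<e> and right-pads with <p> to a fixed length."""
    labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["<p>"]
    for i, string in enumerate(strings):
        tokens = ["<s>", *list(string), "<e>"]
        for j, token in enumerate(tokens):
            labels[i, j] = mapping[token]
    return labels


print(convert_strings_to_labels(["ab", "a"], mapping, length=5))
# tensor([[1, 3, 4, 2, 0],
#         [1, 3, 2, 0, 0]])
```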
diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py
index 01272ba..261c8d3 100644
--- a/text_recognizer/data/iam.py
+++ b/text_recognizer/data/iam.py
@@ -7,7 +7,6 @@ import zipfile
from boltons.cacheutils import cachedproperty
from loguru import logger
-from PIL import Image
import toml
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py
index 3844419..d85787e 100644
--- a/text_recognizer/data/iam_preprocessor.py
+++ b/text_recognizer/data/iam_preprocessor.py
@@ -47,8 +47,6 @@ def load_metadata(
class Preprocessor:
"""A preprocessor for the IAM dataset."""
- # TODO: add lower case only to when generating...
-
def __init__(
self,
data_dir: Union[str, Path],
@@ -57,10 +55,12 @@ class Preprocessor:
lexicon_path: Optional[Union[str, Path]] = None,
use_words: bool = False,
prepend_wordsep: bool = False,
+ special_tokens: Optional[List[str]] = None,
) -> None:
self.wordsep = "▁"
self._use_word = use_words
self._prepend_wordsep = prepend_wordsep
+ self.special_tokens = special_tokens
self.data_dir = Path(data_dir)
@@ -88,6 +88,10 @@ class Preprocessor:
else:
self.lexicon = None
+ if self.special_tokens is not None:
+ self.tokens += self.special_tokens
+ self.graphemes += self.special_tokens
+
self.graphemes_to_index = {t: i for i, t in enumerate(self.graphemes)}
self.tokens_to_index = {t: i for i, t in enumerate(self.tokens)}
self.num_features = num_features
@@ -115,21 +119,31 @@ class Preprocessor:
continue
self.text.append(example["text"].lower())
- def to_index(self, line: str) -> torch.LongTensor:
- """Converts text to a tensor of indices."""
+
+ def _to_index(self, line: str) -> torch.LongTensor:
+ if self.special_tokens is not None and line in self.special_tokens:
+ return torch.LongTensor([self.tokens_to_index[line]])
token_to_index = self.graphemes_to_index
if self.lexicon is not None:
if len(line) > 0:
# If the word is not found in the lexicon, fall back to letters.
- line = [
+ tokens = [
t
for w in line.split(self.wordsep)
for t in self.lexicon.get(w, self.wordsep + w)
]
token_to_index = self.tokens_to_index
if self._prepend_wordsep:
- line = itertools.chain([self.wordsep], line)
- return torch.LongTensor([token_to_index[t] for t in line])
+ tokens = itertools.chain([self.wordsep], tokens)
+ return torch.LongTensor([token_to_index[t] for t in tokens])
+
+ def to_index(self, line: str) -> torch.LongTensor:
+ """Converts text to a tensor of indices."""
+ if self.special_tokens is not None:
+ pattern = f"({'|'.join(self.special_tokens)})"
+ lines = list(filter(None, re.split(pattern, line)))
+ return torch.cat([self._to_index(l) for l in lines])
+ return self._to_index(line)
def to_text(self, indices: List[int]) -> str:
"""Converts indices to text."""
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
index 616e236..297c953 100644
--- a/text_recognizer/data/transforms.py
+++ b/text_recognizer/data/transforms.py
@@ -23,12 +23,12 @@ class ToLower:
class ToCharcters:
"""Converts integers to characters."""
- def __init__(self) -> None:
- self.mapping, _, _ = emnist_mapping()
+ def __init__(self, extra_symbols: Optional[List[str]] = None) -> None:
+ self.mapping, _, _ = emnist_mapping(extra_symbols)
def __call__(self, y: Tensor) -> str:
"""Converts a Tensor to a str."""
- return "".join([self.mapping(int(i)) for i in y]).strip("<p>").replace(" ", "▁")
+ return "".join([self.mapping[int(i)] for i in y]).replace(" ", "▁")
class WordPieces:
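Two details in the `ToCharcters.__call__` fix are easy to miss. The change from `self.mapping(int(i))` to `self.mapping[int(i)]` implies the mapping is a sequence to be indexed, not a callable. And dropping `.strip("<p>")` also removes a latent bug: `str.strip` strips any of the *characters* `<`, `p`, `>` from the ends, not the literal substring, as a quick check shows:

```python
print("<p>apple<p>".strip("<p>"))  # 'apple' — looks right only by coincidence
print("pear".strip("<p>"))         # 'ear'  — the leading 'p' is eaten too
```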
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 3c1919e..0928e6c 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -49,7 +49,7 @@ class LitBaseModel(pl.LightningModule):
optimizer_class = getattr(torch.optim, self._optimizer.type)
return optimizer_class(params=self.parameters(), **args)
- def _configure_lr_scheduler(self) -> Dict[str, Any]:
+ def _configure_lr_scheduler(self, optimizer: torch.optim.Optimizer) -> Dict[str, Any]:
"""Configures the lr scheduler."""
scheduler = {"monitor": self.monitor}
args = {} or self._lr_scheduler.args
@@ -59,13 +59,13 @@ class LitBaseModel(pl.LightningModule):
scheduler["scheduler"] = getattr(
torch.optim.lr_scheduler, self._lr_scheduler.type
- )(**args)
+ )(optimizer, **args)
return scheduler
def configure_optimizers(self) -> Tuple[List[type], List[Dict[str, Any]]]:
"""Configures optimizer and lr scheduler."""
optimizer = self._configure_optimizer()
- scheduler = self._configure_lr_scheduler()
+ scheduler = self._configure_lr_scheduler(optimizer)
return [optimizer], [scheduler]
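The scheduler fix resolves a TypeError: every class in `torch.optim.lr_scheduler` takes the optimizer as its first positional argument, so a scheduler cannot be constructed without one. The corrected wiring, reduced to a standalone sketch (the model and scheduler choice are placeholders, not from this commit):

```python
import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# A scheduler is bound to its optimizer at construction time.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

# PyTorch Lightning's configure_optimizers convention:
# lists of optimizers and of scheduler config dicts.
optimizers = [optimizer]
schedulers = [{"scheduler": scheduler, "monitor": "val_loss"}]
```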
diff --git a/training/experiments/image_transformer.yaml b/text_recognizer/training/experiments/image_transformer.yaml
index bedcbb5..bedcbb5 100644
--- a/training/experiments/image_transformer.yaml
+++ b/text_recognizer/training/experiments/image_transformer.yaml
diff --git a/training/run_experiment.py b/text_recognizer/training/run_experiment.py
index 491112c..ed1a947 100644
--- a/training/run_experiment.py
+++ b/text_recognizer/training/run_experiment.py
@@ -2,7 +2,7 @@
from datetime import datetime
import importlib
from pathlib import Path
-from typing import Dict, List, NamedTuple, Optional, Union, Type
+from typing import Dict, List, Optional, Type
import click
from loguru import logger
@@ -28,7 +28,7 @@ def _configure_logging(log_dir: Optional[Path], verbose: int = 0) -> None:
verbose = min(verbose, 2)
return levels[verbose]
- # Have to remove default logger to get tqdm to work properly.
+ # Remove default logger to get tqdm to work properly.
logger.remove()
# Fetch verbosity level.
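The comment above concerns loguru's default sink: it writes straight to stderr, which corrupts tqdm progress bars mid-render. After `logger.remove()`, the usual replacement routes messages through `tqdm.write`, roughly:

```python
from loguru import logger
from tqdm import tqdm

logger.remove()  # drop the default stderr sink
logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True)
```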