author     Gustaf Rydholm <gustaf.rydholm@gmail.com>   2021-08-02 21:13:48 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>   2021-08-02 21:13:48 +0200
commit     75801019981492eedf9280cb352eea3d8e99b65f (patch)
tree       6521cc4134459e42591b2375f70acd348741474e /text_recognizer/data
parent     e5eca28438cd17d436359f2c6eee0bb9e55d2a8b (diff)
Fix log import, fix mapping in datamodules, and make nn modules hashable
Diffstat (limited to 'text_recognizer/data')
-rw-r--r--  text_recognizer/data/base_data_module.py         14
-rw-r--r--  text_recognizer/data/base_dataset.py              11
-rw-r--r--  text_recognizer/data/download_utils.py             8
-rw-r--r--  text_recognizer/data/emnist.py                    21
-rw-r--r--  text_recognizer/data/emnist_lines.py              21
-rw-r--r--  text_recognizer/data/iam.py                        4
-rw-r--r--  text_recognizer/data/iam_extended_paragraphs.py    6
-rw-r--r--  text_recognizer/data/iam_lines.py                 21
-rw-r--r--  text_recognizer/data/iam_paragraphs.py            18
-rw-r--r--  text_recognizer/data/iam_preprocessor.py          16
-rw-r--r--  text_recognizer/data/iam_synthetic_paragraphs.py  19
-rw-r--r--  text_recognizer/data/make_wordpieces.py            8
-rw-r--r--  text_recognizer/data/mappings.py                  24
13 files changed, 93 insertions, 98 deletions
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index 408ae36..fd914b6 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -1,11 +1,12 @@
"""Base lightning DataModule class."""
from pathlib import Path
-from typing import Any, Dict, Tuple
+from typing import Dict, Tuple
import attr
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
+from text_recognizer.data.mappings import AbstractMapping
from text_recognizer.data.base_dataset import BaseDataset
@@ -24,8 +25,10 @@ class BaseDataModule(LightningDataModule):
def __attrs_pre_init__(self) -> None:
super().__init__()
+ mapping: AbstractMapping = attr.ib()
batch_size: int = attr.ib(default=16)
num_workers: int = attr.ib(default=0)
+ pin_memory: bool = attr.ib(default=True)
# Placeholders
data_train: BaseDataset = attr.ib(init=False, default=None)
@@ -33,8 +36,6 @@ class BaseDataModule(LightningDataModule):
data_test: BaseDataset = attr.ib(init=False, default=None)
dims: Tuple[int, ...] = attr.ib(init=False, default=None)
output_dims: Tuple[int, ...] = attr.ib(init=False, default=None)
- mapping: Any = attr.ib(init=False, default=None)
- inverse_mapping: Dict[str, int] = attr.ib(init=False)
@classmethod
def data_dirname(cls) -> Path:
@@ -46,7 +47,6 @@ class BaseDataModule(LightningDataModule):
return {
"input_dim": self.dims,
"output_dims": self.output_dims,
- "mapping": self.mapping,
}
def prepare_data(self) -> None:
@@ -72,7 +72,7 @@ class BaseDataModule(LightningDataModule):
shuffle=True,
batch_size=self.batch_size,
num_workers=self.num_workers,
- pin_memory=True,
+ pin_memory=self.pin_memory,
)
def val_dataloader(self) -> DataLoader:
@@ -82,7 +82,7 @@ class BaseDataModule(LightningDataModule):
shuffle=False,
batch_size=self.batch_size,
num_workers=self.num_workers,
- pin_memory=True,
+ pin_memory=self.pin_memory,
)
def test_dataloader(self) -> DataLoader:
@@ -92,5 +92,5 @@ class BaseDataModule(LightningDataModule):
shuffle=False,
batch_size=self.batch_size,
num_workers=self.num_workers,
- pin_memory=True,
+ pin_memory=self.pin_memory,
)
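
With this change the mapping is injected into the datamodule instead of being built inside it, and pin_memory becomes a configurable field. A minimal usage sketch, assuming the EmnistMapping class introduced in mappings.py below and the IAMParagraphs module further down; argument values are illustrative:

from text_recognizer.data.mappings import EmnistMapping
from text_recognizer.data.iam_paragraphs import IAMParagraphs

# The newline token is added because IAM paragraphs contain line breaks.
mapping = EmnistMapping(extra_symbols={"\n"})
datamodule = IAMParagraphs(
    mapping=mapping,          # required attr.ib() with no default
    batch_size=32,
    num_workers=4,
    pin_memory=False,         # new attr.ib(default=True), now overridable
)
datamodule.prepare_data()
datamodule.setup("fit")
train_loader = datamodule.train_dataloader()  # built with pin_memory=self.pin_memory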
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py
index c26f1c9..8640d92 100644
--- a/text_recognizer/data/base_dataset.py
+++ b/text_recognizer/data/base_dataset.py
@@ -1,5 +1,5 @@
"""Base PyTorch Dataset class."""
-from typing import Any, Callable, Dict, Sequence, Tuple, Union
+from typing import Callable, Dict, Optional, Sequence, Tuple, Union
import attr
import torch
@@ -22,14 +22,13 @@ class BaseDataset(Dataset):
data: Union[Sequence, Tensor] = attr.ib()
targets: Union[Sequence, Tensor] = attr.ib()
- transform: Callable = attr.ib()
- target_transform: Callable = attr.ib()
+ transform: Optional[Callable] = attr.ib(default=None)
+ target_transform: Optional[Callable] = attr.ib(default=None)
def __attrs_pre_init__(self) -> None:
super().__init__()
def __attrs_post_init__(self) -> None:
- # TODO: refactor this
if len(self.data) != len(self.targets):
raise ValueError("Data and targets must be of equal length.")
@@ -37,14 +36,14 @@ class BaseDataset(Dataset):
"""Return the length of the dataset."""
return len(self.data)
- def __getitem__(self, index: int) -> Tuple[Any, Any]:
+ def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
"""Return a datum and its target, after processing by transforms.
Args:
index (int): Index of a datum in the dataset.
Returns:
- Tuple[Any, Any]: Datum and target pair.
+ Tuple[Tensor, Tensor]: Datum and target pair.
"""
datum, target = self.data[index], self.targets[index]
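
Since transform and target_transform now default to None, a dataset can be constructed from tensors alone. A small sketch under that assumption, using dummy data rather than the project's real samples:

import torch

from text_recognizer.data.base_dataset import BaseDataset

data = torch.rand(10, 1, 28, 28)        # ten dummy images
targets = torch.randint(0, 58, (10,))   # ten dummy labels

dataset = BaseDataset(data, targets)    # both transforms default to None
datum, target = dataset[0]              # __getitem__ -> Tuple[Tensor, Tensor]
assert len(dataset) == 10

# Mismatched lengths now fail fast in __attrs_post_init__:
# BaseDataset(data, targets[:5]) raises ValueError("Data and targets must be of equal length.")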
diff --git a/text_recognizer/data/download_utils.py b/text_recognizer/data/download_utils.py
index e3dc68c..8938830 100644
--- a/text_recognizer/data/download_utils.py
+++ b/text_recognizer/data/download_utils.py
@@ -4,7 +4,7 @@ from pathlib import Path
from typing import Dict, List, Optional
from urllib.request import urlretrieve
-from loguru import logger
+from loguru import logger as log
from tqdm import tqdm
@@ -32,7 +32,7 @@ class TqdmUpTo(tqdm):
total_size (Optional[int]): Total size in tqdm units. Defaults to None.
"""
if total_size is not None:
- self.total = total_size # pylint: disable=attribute-defined-outside-init
+ self.total = total_size
self.update(blocks * block_size - self.n)
@@ -62,9 +62,9 @@ def download_dataset(metadata: Dict, dl_dir: Path) -> Optional[Path]:
filename = dl_dir / metadata["filename"]
if filename.exists():
return
- logger.info(f"Downloading raw dataset from {metadata['url']} to {filename}...")
+ log.info(f"Downloading raw dataset from {metadata['url']} to {filename}...")
_download_url(metadata["url"], filename)
- logger.info("Computing the SHA-256...")
+ log.info("Computing the SHA-256...")
sha256 = _compute_sha256(filename)
if sha256 != metadata["sha256"]:
raise ValueError(
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index 2d0ac29..c6be123 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -3,12 +3,12 @@ import json
import os
from pathlib import Path
import shutil
-from typing import Callable, Dict, List, Optional, Sequence, Tuple
+from typing import Callable, Dict, List, Optional, Set, Sequence, Tuple
import zipfile
import attr
import h5py
-from loguru import logger
+from loguru import logger as log
import numpy as np
import toml
import torchvision.transforms as T
@@ -50,8 +50,7 @@ class EMNIST(BaseDataModule):
transform: Callable = attr.ib(init=False, default=T.Compose([T.ToTensor()]))
def __attrs_post_init__(self) -> None:
- self.mapping, self.inverse_mapping, input_shape = emnist_mapping()
- self.dims = (1, *input_shape)
+ self.dims = (1, *self.mapping.input_size)
def prepare_data(self) -> None:
"""Downloads dataset if not present."""
@@ -106,7 +105,7 @@ class EMNIST(BaseDataModule):
def emnist_mapping(
- extra_symbols: Optional[Sequence[str]] = None,
+ extra_symbols: Optional[Set[str]] = None,
) -> Tuple[List, Dict[str, int], List[int]]:
"""Return the EMNIST mapping."""
if not ESSENTIALS_FILENAME.exists():
@@ -130,7 +129,7 @@ def download_and_process_emnist() -> None:
def _process_raw_dataset(filename: str, dirname: Path) -> None:
"""Processes the raw EMNIST dataset."""
- logger.info("Unzipping EMNIST...")
+ log.info("Unzipping EMNIST...")
curdir = os.getcwd()
os.chdir(dirname)
content = zipfile.ZipFile(filename, "r")
@@ -138,7 +137,7 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
from scipy.io import loadmat
- logger.info("Loading training data from .mat file")
+ log.info("Loading training data from .mat file")
data = loadmat("matlab/emnist-byclass.mat")
x_train = (
data["dataset"]["train"][0, 0]["images"][0, 0]
@@ -152,11 +151,11 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
if SAMPLE_TO_BALANCE:
- logger.info("Balancing classes to reduce amount of data")
+ log.info("Balancing classes to reduce amount of data")
x_train, y_train = _sample_to_balance(x_train, y_train)
x_test, y_test = _sample_to_balance(x_test, y_test)
- logger.info("Saving to HDF5 in a compressed format...")
+ log.info("Saving to HDF5 in a compressed format...")
PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
with h5py.File(PROCESSED_DATA_FILENAME, "w") as f:
f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf")
@@ -164,7 +163,7 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf")
f.create_dataset("y_test", data=y_test, dtype="u1", compression="lzf")
- logger.info("Saving essential dataset parameters to text_recognizer/datasets...")
+ log.info("Saving essential dataset parameters to text_recognizer/datasets...")
mapping = {int(k): chr(v) for k, v in data["dataset"]["mapping"][0, 0]}
characters = _augment_emnist_characters(mapping.values())
essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])}
@@ -172,7 +171,7 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
with ESSENTIALS_FILENAME.open(mode="w") as f:
json.dump(essentials, f)
- logger.info("Cleaning up...")
+ log.info("Cleaning up...")
shutil.rmtree("matlab")
os.chdir(curdir)
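
emnist_mapping now takes its extra symbols as a set. A small sketch of the call, assuming the essentials file has already been generated by download_and_process_emnist:

from text_recognizer.data.emnist import emnist_mapping

# Returns (mapping, inverse_mapping, input_shape); extra symbols are appended
# after the EMNIST characters.
mapping, inverse_mapping, input_shape = emnist_mapping(extra_symbols={"\n"})

assert inverse_mapping[mapping[5]] == 5   # the two mappings are mutual inverses
print(input_shape)                        # e.g. [28, 28] for the EMNIST crops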
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 7548ad5..5298726 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -1,11 +1,11 @@
"""Dataset of generated text from EMNIST characters."""
from collections import defaultdict
from pathlib import Path
-from typing import Callable, Dict, Tuple
+from typing import Callable, List, Tuple
import attr
import h5py
-from loguru import logger
+from loguru import logger as log
import numpy as np
import torch
from torchvision import transforms
@@ -46,8 +46,7 @@ class EMNISTLines(BaseDataModule):
emnist: EMNIST = attr.ib(init=False, default=None)
def __attrs_post_init__(self) -> None:
- self.emnist = EMNIST()
- self.mapping = self.emnist.mapping
+ self.emnist = EMNIST(mapping=self.mapping)
max_width = (
int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap))
@@ -86,7 +85,7 @@ class EMNISTLines(BaseDataModule):
self._generate_data("test")
def setup(self, stage: str = None) -> None:
- logger.info("EMNISTLinesDataset loading data from HDF5...")
+ log.info("EMNISTLinesDataset loading data from HDF5...")
if stage == "fit" or stage is None:
print(self.data_filename)
with h5py.File(self.data_filename, "r") as f:
@@ -137,7 +136,7 @@ class EMNISTLines(BaseDataModule):
return basic + data
def _generate_data(self, split: str) -> None:
- logger.info(f"EMNISTLines generating data for {split}...")
+ log.info(f"EMNISTLines generating data for {split}...")
sentence_generator = SentenceGenerator(
self.max_length - 2
) # Subtract by 2 because start/end token
@@ -148,17 +147,17 @@ class EMNISTLines(BaseDataModule):
if split == "train":
samples_by_char = _get_samples_by_char(
- emnist.x_train, emnist.y_train, emnist.mapping
+ emnist.x_train, emnist.y_train, self.mapping.mapping
)
num = self.num_train
elif split == "val":
samples_by_char = _get_samples_by_char(
- emnist.x_train, emnist.y_train, emnist.mapping
+ emnist.x_train, emnist.y_train, self.mapping.mapping
)
num = self.num_val
else:
samples_by_char = _get_samples_by_char(
- emnist.x_test, emnist.y_test, emnist.mapping
+ emnist.x_test, emnist.y_test, self.mapping.mapping
)
num = self.num_test
@@ -173,14 +172,14 @@ class EMNISTLines(BaseDataModule):
self.dims,
)
y = convert_strings_to_labels(
- y, emnist.inverse_mapping, length=MAX_OUTPUT_LENGTH
+ y, self.mapping.inverse_mapping, length=MAX_OUTPUT_LENGTH
)
f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf")
f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf")
def _get_samples_by_char(
- samples: np.ndarray, labels: np.ndarray, mapping: Dict
+ samples: np.ndarray, labels: np.ndarray, mapping: List
) -> defaultdict:
samples_by_char = defaultdict(list)
for sample, label in zip(samples, labels):
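
_get_samples_by_char now takes the list form of the mapping (index -> character) obtained from self.mapping.mapping. A toy sketch of the grouping it presumably performs; this is an illustrative re-implementation with dummy data, not the repository's exact body:

from collections import defaultdict

import numpy as np

def get_samples_by_char_sketch(samples, labels, mapping):
    # Group each sample under the character its label index maps to.
    samples_by_char = defaultdict(list)
    for sample, label in zip(samples, labels):
        samples_by_char[mapping[label]].append(sample)
    return samples_by_char

mapping = ["<b>", "<s>", "<e>", "<p>", "a", "b"]   # toy list mapping
samples = np.zeros((3, 28, 28), dtype=np.uint8)    # three dummy crops
labels = np.array([4, 5, 4])

grouped = get_samples_by_char_sketch(samples, labels, mapping)
assert len(grouped["a"]) == 2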
diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py
index 3982c4f..7278eb2 100644
--- a/text_recognizer/data/iam.py
+++ b/text_recognizer/data/iam.py
@@ -7,7 +7,7 @@ import zipfile
import attr
from boltons.cacheutils import cachedproperty
-from loguru import logger
+from loguru import logger as log
import toml
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
@@ -92,7 +92,7 @@ class IAM(BaseDataModule):
def _extract_raw_dataset(filename: Path, dirname: Path) -> None:
- logger.info("Extracting IAM data...")
+ log.info("Extracting IAM data...")
curdir = os.getcwd()
os.chdir(dirname)
with zipfile.ZipFile(filename, "r") as f:
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 0e97801..ccf0759 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -4,7 +4,6 @@ from typing import Dict, List
import attr
from torch.utils.data import ConcatDataset
-from text_recognizer.data.base_dataset import BaseDataset
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.iam_paragraphs import IAMParagraphs
from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs
@@ -20,6 +19,7 @@ class IAMExtendedParagraphs(BaseDataModule):
def __attrs_post_init__(self) -> None:
self.iam_paragraphs = IAMParagraphs(
+ mapping=self.mapping,
batch_size=self.batch_size,
num_workers=self.num_workers,
train_fraction=self.train_fraction,
@@ -27,6 +27,7 @@ class IAMExtendedParagraphs(BaseDataModule):
word_pieces=self.word_pieces,
)
self.iam_synthetic_paragraphs = IAMSyntheticParagraphs(
+ mapping=self.mapping,
batch_size=self.batch_size,
num_workers=self.num_workers,
train_fraction=self.train_fraction,
@@ -36,7 +37,6 @@ class IAMExtendedParagraphs(BaseDataModule):
self.dims = self.iam_paragraphs.dims
self.output_dims = self.iam_paragraphs.output_dims
- self.num_classes = self.iam_paragraphs.num_classes
def prepare_data(self) -> None:
"""Prepares the paragraphs data."""
@@ -58,7 +58,7 @@ class IAMExtendedParagraphs(BaseDataModule):
"""Returns info about the dataset."""
basic = (
"IAM Original and Synthetic Paragraphs Dataset\n" # pylint: disable=no-member
- f"Num classes: {len(self.num_classes)}\n"
+ f"Num classes: {len(self.mapping)}\n"
f"Dims: {self.dims}\n"
f"Output dims: {self.output_dims}\n"
)
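
Both composed datamodules now share the single mapping instance passed to IAMExtendedParagraphs, and the class count is read from that mapping instead of a stored num_classes. A minimal sketch, assuming EmnistMapping from mappings.py with the newline token added:

from text_recognizer.data.mappings import EmnistMapping
from text_recognizer.data.iam_extended_paragraphs import IAMExtendedParagraphs

mapping = EmnistMapping(extra_symbols={"\n"})
datamodule = IAMExtendedParagraphs(mapping=mapping, batch_size=16, num_workers=2)

# The same mapping object is forwarded to IAMParagraphs and IAMSyntheticParagraphs,
# so both label against identical indices; the info string now reports len(self.mapping).
print(len(mapping))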
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index b7f3fdd..1c63729 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -2,15 +2,14 @@
If not created, will generate a handwritten lines dataset from the IAM paragraphs
dataset.
-
"""
import json
from pathlib import Path
import random
-from typing import Dict, List, Sequence, Tuple
+from typing import List, Sequence, Tuple
import attr
-from loguru import logger
+from loguru import logger as log
from PIL import Image, ImageFile, ImageOps
import numpy as np
from torch import Tensor
@@ -23,7 +22,7 @@ from text_recognizer.data.base_dataset import (
split_dataset,
)
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
-from text_recognizer.data.emnist import emnist_mapping
+from text_recognizer.data.mappings import EmnistMapping
from text_recognizer.data.iam import IAM
from text_recognizer.data import image_utils
@@ -48,17 +47,13 @@ class IAMLines(BaseDataModule):
)
output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1))
- def __attrs_post_init__(self) -> None:
- # TODO: refactor this
- self.mapping, self.inverse_mapping, _ = emnist_mapping()
-
def prepare_data(self) -> None:
"""Creates the IAM lines dataset if not existing."""
if PROCESSED_DATA_DIRNAME.exists():
return
- logger.info("Cropping IAM lines regions...")
- iam = IAM()
+ log.info("Cropping IAM lines regions...")
+ iam = IAM(mapping=EmnistMapping())
iam.prepare_data()
crops_train, labels_train = line_crops_and_labels(iam, "train")
crops_test, labels_test = line_crops_and_labels(iam, "test")
@@ -66,7 +61,7 @@ class IAMLines(BaseDataModule):
shapes = np.array([crop.size for crop in crops_train + crops_test])
aspect_ratios = shapes[:, 0] / shapes[:, 1]
- logger.info("Saving images, labels, and statistics...")
+ log.info("Saving images, labels, and statistics...")
save_images_and_labels(
crops_train, labels_train, "train", PROCESSED_DATA_DIRNAME
)
@@ -91,7 +86,7 @@ class IAMLines(BaseDataModule):
raise ValueError("Target length longer than max output length.")
y_train = convert_strings_to_labels(
- labels_train, self.inverse_mapping, length=self.output_dims[0]
+ labels_train, self.mapping.inverse_mapping, length=self.output_dims[0]
)
data_train = BaseDataset(
x_train, y_train, transform=get_transform(IMAGE_WIDTH, self.augment)
@@ -110,7 +105,7 @@ class IAMLines(BaseDataModule):
raise ValueError("Taget length longer than max output length.")
y_test = convert_strings_to_labels(
- labels_test, self.inverse_mapping, length=self.output_dims[0]
+ labels_test, self.mapping.inverse_mapping, length=self.output_dims[0]
)
self.data_test = BaseDataset(
x_test, y_test, transform=get_transform(IMAGE_WIDTH)
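
Label encoding now goes through self.mapping.inverse_mapping. The sketch below is an illustrative re-implementation of what convert_strings_to_labels does with that dict, not the project's exact code; the <s>/<e>/<p> token names are assumptions:

from typing import Dict, Sequence

import torch
from torch import Tensor

def convert_strings_to_labels_sketch(
    strings: Sequence[str], mapping: Dict[str, int], length: int
) -> Tensor:
    # Encode each string as a fixed-length row of indices, padded with <p>.
    labels = torch.full((len(strings), length), mapping["<p>"], dtype=torch.long)
    for i, string in enumerate(strings):
        tokens = ["<s>", *string, "<e>"]
        for j, token in enumerate(tokens):
            labels[i, j] = mapping[token]
    return labels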
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 0f3a2ce..6189f7d 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -4,7 +4,7 @@ from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple
import attr
-from loguru import logger
+from loguru import logger as log
import numpy as np
from PIL import Image, ImageOps
import torchvision.transforms as T
@@ -17,9 +17,8 @@ from text_recognizer.data.base_dataset import (
split_dataset,
)
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
-from text_recognizer.data.emnist import emnist_mapping
+from text_recognizer.data.mappings import EmnistMapping
from text_recognizer.data.iam import IAM
-from text_recognizer.data.mappings import WordPieceMapping
from text_recognizer.data.transforms import WordPiece
@@ -38,7 +37,6 @@ MAX_LABEL_LENGTH = 682
class IAMParagraphs(BaseDataModule):
"""IAM handwriting database paragraphs."""
- num_classes: int = attr.ib()
word_pieces: bool = attr.ib(default=False)
augment: bool = attr.ib(default=True)
train_fraction: float = attr.ib(default=0.8)
@@ -46,21 +44,17 @@ class IAMParagraphs(BaseDataModule):
init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH)
)
output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1))
- inverse_mapping: Dict[str, int] = attr.ib(init=False)
-
- def __attrs_post_init__(self) -> None:
- _, self.inverse_mapping, _ = emnist_mapping(extra_symbols=[NEW_LINE_TOKEN])
def prepare_data(self) -> None:
"""Create data for training/testing."""
if PROCESSED_DATA_DIRNAME.exists():
return
- logger.info(
+ log.info(
"Cropping IAM paragraph regions and saving them along with labels..."
)
- iam = IAM()
+ iam = IAM(mapping=EmnistMapping())
iam.prepare_data()
properties = {}
@@ -89,7 +83,7 @@ class IAMParagraphs(BaseDataModule):
crops, labels = _load_processed_crops_and_labels(split)
data = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops]
targets = convert_strings_to_labels(
- strings=labels, mapping=self.inverse_mapping, length=self.output_dims[0]
+ strings=labels, mapping=self.mapping.inverse_mapping, length=self.output_dims[0]
)
return BaseDataset(
data,
@@ -98,7 +92,7 @@ class IAMParagraphs(BaseDataModule):
target_transform=get_target_transform(self.word_pieces),
)
- logger.info(f"Loading IAM paragraph regions and lines for {stage}...")
+ log.info(f"Loading IAM paragraph regions and lines for {stage}...")
_validate_data_dims(input_dims=self.dims, output_dims=self.output_dims)
if stage == "fit" or stage is None:
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py
index 93a13bb..bcd77b4 100644
--- a/text_recognizer/data/iam_preprocessor.py
+++ b/text_recognizer/data/iam_preprocessor.py
@@ -1,18 +1,16 @@
"""Preprocessor for extracting word letters from the IAM dataset.
The code is mostly stolen from:
-
https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py
-
"""
import collections
import itertools
from pathlib import Path
import re
-from typing import List, Optional, Union, Sequence
+from typing import List, Optional, Union, Set
import click
-from loguru import logger
+from loguru import logger as log
import torch
@@ -57,7 +55,7 @@ class Preprocessor:
lexicon_path: Optional[Union[str, Path]] = None,
use_words: bool = False,
prepend_wordsep: bool = False,
- special_tokens: Optional[Sequence[str]] = None,
+ special_tokens: Optional[Set[str]] = None,
) -> None:
self.wordsep = "▁"
self._use_word = use_words
@@ -186,7 +184,7 @@ def cli(
/ "iam"
/ "iamdb"
)
- logger.debug(f"Using data dir: {data_dir}")
+ log.debug(f"Using data dir: {data_dir}")
if not data_dir.exists():
raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
else:
@@ -196,15 +194,15 @@ def cli(
preprocessor.extract_train_text()
processed_dir = data_dir.parents[2] / "processed" / "iam_lines"
- logger.debug(f"Saving processed files at: {processed_dir}")
+ log.debug(f"Saving processed files at: {processed_dir}")
if save_text is not None:
- logger.info("Saving training text")
+ log.info("Saving training text")
with open(processed_dir / save_text, "w") as f:
f.write("\n".join(t for t in preprocessor.text))
if save_tokens is not None:
- logger.info("Saving tokens")
+ log.info("Saving tokens")
with open(processed_dir / save_tokens, "w") as f:
f.write("\n".join(preprocessor.tokens))
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index f00a494..c938f8b 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -3,7 +3,7 @@ import random
from typing import Any, List, Sequence, Tuple
import attr
-from loguru import logger
+from loguru import logger as log
import numpy as np
from PIL import Image
@@ -21,6 +21,7 @@ from text_recognizer.data.iam_paragraphs import (
IMAGE_SCALE_FACTOR,
resize_image,
)
+from text_recognizer.data.mappings import EmnistMapping
from text_recognizer.data.iam import IAM
from text_recognizer.data.iam_lines import (
line_crops_and_labels,
@@ -43,10 +44,10 @@ class IAMSyntheticParagraphs(IAMParagraphs):
if PROCESSED_DATA_DIRNAME.exists():
return
- logger.info("Preparing IAM lines for synthetic paragraphs dataset.")
- logger.info("Cropping IAM line regions and loading labels.")
+ log.info("Preparing IAM lines for synthetic paragraphs dataset.")
+ log.info("Cropping IAM line regions and loading labels.")
- iam = IAM()
+ iam = IAM(mapping=EmnistMapping())
iam.prepare_data()
crops_train, labels_train = line_crops_and_labels(iam, "train")
@@ -55,7 +56,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
crops_train = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops_train]
crops_test = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops_test]
- logger.info(f"Saving images and labels at {PROCESSED_DATA_DIRNAME}")
+ log.info(f"Saving images and labels at {PROCESSED_DATA_DIRNAME}")
save_images_and_labels(
crops_train, labels_train, "train", PROCESSED_DATA_DIRNAME
)
@@ -64,7 +65,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
def setup(self, stage: str = None) -> None:
"""Loading synthetic dataset."""
- logger.info(f"IAM Synthetic dataset steup for stage {stage}...")
+ log.info(f"IAM Synthetic dataset steup for stage {stage}...")
if stage == "fit" or stage is None:
line_crops, line_labels = load_line_crops_and_labels(
@@ -76,7 +77,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
targets = convert_strings_to_labels(
strings=paragraphs_labels,
- mapping=self.inverse_mapping,
+ mapping=self.mapping.inverse_mapping,
length=self.output_dims[0],
)
self.data_train = BaseDataset(
@@ -144,7 +145,7 @@ def generate_synthetic_paragraphs(
[line_labels[i] for i in paragraph_indices]
)
if len(paragraph_label) > paragraphs_properties["label_length"]["max"]:
- logger.info(
+ log.info(
"Label longer than longest label in original IAM paragraph dataset - hence dropping."
)
continue
@@ -158,7 +159,7 @@ def generate_synthetic_paragraphs(
paragraph_crop.height > max_paragraph_shape[0]
or paragraph_crop.width > max_paragraph_shape[1]
):
- logger.info(
+ log.info(
"Crop larger than largest crop in original IAM paragraphs dataset - hence dropping"
)
continue
diff --git a/text_recognizer/data/make_wordpieces.py b/text_recognizer/data/make_wordpieces.py
index ef9eb1b..40fbee4 100644
--- a/text_recognizer/data/make_wordpieces.py
+++ b/text_recognizer/data/make_wordpieces.py
@@ -10,7 +10,7 @@ from pathlib import Path
from typing import List, Optional, Union
import click
-from loguru import logger
+from loguru import logger as log
import sentencepiece as spm
from text_recognizer.data.iam_preprocessor import load_metadata
@@ -63,9 +63,9 @@ def save_pieces(
vocab: set,
) -> None:
"""Saves word pieces to disk."""
- logger.info(f"Generating word piece list of size {num_pieces}.")
+ log.info(f"Generating word piece list of size {num_pieces}.")
pieces = [sp.id_to_piece(i) for i in range(1, num_pieces + 1)]
- logger.info(f"Encoding vocabulary of size {len(vocab)}.")
+ log.info(f"Encoding vocabulary of size {len(vocab)}.")
encoded_vocab = [sp.encode_as_pieces(v) for v in vocab]
# Save pieces to file.
@@ -101,7 +101,7 @@ def cli(
data_dir = (
Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines"
)
- logger.debug(f"Using data dir: {data_dir}")
+ log.debug(f"Using data dir: {data_dir}")
if not data_dir.exists():
raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
else:
diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py
index b69e888..d1c64dd 100644
--- a/text_recognizer/data/mappings.py
+++ b/text_recognizer/data/mappings.py
@@ -1,18 +1,30 @@
"""Mapping to and from word pieces."""
from abc import ABC, abstractmethod
from pathlib import Path
-from typing import Dict, List, Optional, Union, Set, Sequence
+from typing import Dict, List, Optional, Union, Set
import attr
-import loguru.logger as log
import torch
+from loguru import logger as log
from torch import Tensor
from text_recognizer.data.emnist import emnist_mapping
from text_recognizer.data.iam_preprocessor import Preprocessor
+@attr.s
class AbstractMapping(ABC):
+ input_size: List[int] = attr.ib(init=False)
+ mapping: List[str] = attr.ib(init=False)
+ inverse_mapping: Dict[str, int] = attr.ib(init=False)
+
+ def __len__(self) -> int:
+ return len(self.mapping)
+
+ @property
+ def num_classes(self) -> int:
+ return self.__len__()
+
@abstractmethod
def get_token(self, *args, **kwargs) -> str:
...
@@ -30,15 +42,13 @@ class AbstractMapping(ABC):
...
-@attr.s
+@attr.s(auto_attribs=True)
class EmnistMapping(AbstractMapping):
- extra_symbols: Optional[Set[str]] = attr.ib(default=None, converter=set)
- mapping: Sequence[str] = attr.ib(init=False)
- inverse_mapping: Dict[str, int] = attr.ib(init=False)
- input_size: List[int] = attr.ib(init=False)
+ extra_symbols: Optional[Set[str]] = attr.ib(default=None)
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
+ self.extra_symbols = set(self.extra_symbols) if self.extra_symbols is not None else None
self.mapping, self.inverse_mapping, self.input_size = emnist_mapping(
self.extra_symbols
)
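
With the shared fields hoisted into AbstractMapping, every concrete mapping exposes mapping, inverse_mapping, input_size, __len__ and num_classes. A small usage sketch, assuming EmnistMapping also implements the abstract get_token/get_index methods not shown in this hunk:

from text_recognizer.data.mappings import EmnistMapping

mapping = EmnistMapping(extra_symbols={"\n"})  # set conversion now happens in __attrs_post_init__

print(len(mapping))          # __len__ delegates to the character list
print(mapping.num_classes)   # property wrapping __len__
print(mapping.input_size)    # e.g. [28, 28], from the EMNIST essentials file

index = mapping.inverse_mapping["a"]
assert mapping.mapping[index] == "a"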