From 75801019981492eedf9280cb352eea3d8e99b65f Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Mon, 2 Aug 2021 21:13:48 +0200
Subject: Fix log import, fix mapping in datamodules, make nn modules hashable

---
 text_recognizer/data/base_data_module.py         | 14 +++++++-------
 text_recognizer/data/base_dataset.py             | 11 +++++------
 text_recognizer/data/download_utils.py           |  8 ++++----
 text_recognizer/data/emnist.py                   | 21 ++++++++++-----------
 text_recognizer/data/emnist_lines.py             | 21 ++++++++++-----------
 text_recognizer/data/iam.py                      |  4 ++--
 text_recognizer/data/iam_extended_paragraphs.py  |  6 +++---
 text_recognizer/data/iam_lines.py                | 21 ++++++++-------------
 text_recognizer/data/iam_paragraphs.py           | 18 ++++++------------
 text_recognizer/data/iam_preprocessor.py         | 16 +++++++---------
 text_recognizer/data/iam_synthetic_paragraphs.py | 19 ++++++++++---------
 text_recognizer/data/make_wordpieces.py          |  8 ++++----
 text_recognizer/data/mappings.py                 | 24 +++++++++++++++++-------
 13 files changed, 93 insertions(+), 98 deletions(-)
 (limited to 'text_recognizer/data')

diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index 408ae36..fd914b6 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -1,11 +1,12 @@
 """Base lightning DataModule class."""
 from pathlib import Path
-from typing import Any, Dict, Tuple
+from typing import Dict, Tuple
 
 import attr
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader
 
+from text_recognizer.data.mappings import AbstractMapping
 from text_recognizer.data.base_dataset import BaseDataset
 
 
@@ -24,8 +25,10 @@ class BaseDataModule(LightningDataModule):
     def __attrs_pre_init__(self) -> None:
         super().__init__()
 
+    mapping: AbstractMapping = attr.ib()
     batch_size: int = attr.ib(default=16)
     num_workers: int = attr.ib(default=0)
+    pin_memory: bool = attr.ib(default=True)
 
     # Placeholders
     data_train: BaseDataset = attr.ib(init=False, default=None)
@@ -33,8 +36,6 @@ class BaseDataModule(LightningDataModule):
     data_test: BaseDataset = attr.ib(init=False, default=None)
     dims: Tuple[int, ...] = attr.ib(init=False, default=None)
     output_dims: Tuple[int, ...] = attr.ib(init=False, default=None)
-    mapping: Any = attr.ib(init=False, default=None)
-    inverse_mapping: Dict[str, int] = attr.ib(init=False)
 
     @classmethod
     def data_dirname(cls) -> Path:
@@ -46,7 +47,6 @@ class BaseDataModule(LightningDataModule):
         return {
             "input_dim": self.dims,
             "output_dims": self.output_dims,
-            "mapping": self.mapping,
         }
 
     def prepare_data(self) -> None:
@@ -72,7 +72,7 @@ class BaseDataModule(LightningDataModule):
             shuffle=True,
             batch_size=self.batch_size,
             num_workers=self.num_workers,
-            pin_memory=True,
+            pin_memory=self.pin_memory,
         )
 
     def val_dataloader(self) -> DataLoader:
@@ -82,7 +82,7 @@ class BaseDataModule(LightningDataModule):
             shuffle=False,
             batch_size=self.batch_size,
             num_workers=self.num_workers,
-            pin_memory=True,
+            pin_memory=self.pin_memory,
         )
 
     def test_dataloader(self) -> DataLoader:
@@ -92,5 +92,5 @@ class BaseDataModule(LightningDataModule):
             shuffle=False,
             batch_size=self.batch_size,
             num_workers=self.num_workers,
-            pin_memory=True,
+            pin_memory=self.pin_memory,
         )
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py
index c26f1c9..8640d92 100644
--- a/text_recognizer/data/base_dataset.py
+++ b/text_recognizer/data/base_dataset.py
@@ -1,5 +1,5 @@
 """Base PyTorch Dataset class."""
-from typing import Any, Callable, Dict, Sequence, Tuple, Union
+from typing import Callable, Dict, Optional, Sequence, Tuple, Union
 
 import attr
 import torch
@@ -22,14 +22,13 @@ class BaseDataset(Dataset):
 
     data: Union[Sequence, Tensor] = attr.ib()
     targets: Union[Sequence, Tensor] = attr.ib()
-    transform: Callable = attr.ib()
-    target_transform: Callable = attr.ib()
+    transform: Optional[Callable] = attr.ib(default=None)
+    target_transform: Optional[Callable] = attr.ib(default=None)
 
     def __attrs_pre_init__(self) -> None:
         super().__init__()
 
     def __attrs_post_init__(self) -> None:
-        # TODO: refactor this
         if len(self.data) != len(self.targets):
             raise ValueError("Data and targets must be of equal length.")
 
@@ -37,14 +36,14 @@ class BaseDataset(Dataset):
         """Return the length of the dataset."""
         return len(self.data)
 
-    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+    def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
         """Return a datum and its target, after processing by transforms.
 
         Args:
             index (int): Index of a datum in the dataset.
 
         Returns:
-            Tuple[Any, Any]: Datum and target pair.
+            Tuple[Tensor, Tensor]: Datum and target pair.
 
         """
         datum, target = self.data[index], self.targets[index]
diff --git a/text_recognizer/data/download_utils.py b/text_recognizer/data/download_utils.py
index e3dc68c..8938830 100644
--- a/text_recognizer/data/download_utils.py
+++ b/text_recognizer/data/download_utils.py
@@ -4,7 +4,7 @@ from pathlib import Path
 from typing import Dict, List, Optional
 from urllib.request import urlretrieve
 
-from loguru import logger
+from loguru import logger as log
 from tqdm import tqdm
 
 
@@ -32,7 +32,7 @@ class TqdmUpTo(tqdm):
             total_size (Optional[int]): Total size in tqdm units. Defaults to None.
 
         """
         if total_size is not None:
-            self.total = total_size  # pylint: disable=attribute-defined-outside-init
+            self.total = total_size
         self.update(blocks * block_size - self.n)
 
@@ -62,9 +62,9 @@ def download_dataset(metadata: Dict, dl_dir: Path) -> Optional[Path]:
     filename = dl_dir / metadata["filename"]
     if filename.exists():
         return
-    logger.info(f"Downloading raw dataset from {metadata['url']} to {filename}...")
+    log.info(f"Downloading raw dataset from {metadata['url']} to {filename}...")
     _download_url(metadata["url"], filename)
-    logger.info("Computing the SHA-256...")
+    log.info("Computing the SHA-256...")
     sha256 = _compute_sha256(filename)
     if sha256 != metadata["sha256"]:
         raise ValueError(
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index 2d0ac29..c6be123 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -3,12 +3,12 @@ import json
 import os
 from pathlib import Path
 import shutil
-from typing import Callable, Dict, List, Optional, Sequence, Tuple
+from typing import Callable, Dict, List, Optional, Set, Sequence, Tuple
 import zipfile
 
 import attr
 import h5py
-from loguru import logger
+from loguru import logger as log
 import numpy as np
 import toml
 import torchvision.transforms as T
@@ -50,8 +50,7 @@ class EMNIST(BaseDataModule):
     transform: Callable = attr.ib(init=False, default=T.Compose([T.ToTensor()]))
 
     def __attrs_post_init__(self) -> None:
-        self.mapping, self.inverse_mapping, input_shape = emnist_mapping()
-        self.dims = (1, *input_shape)
+        self.dims = (1, *self.mapping.input_size)
 
     def prepare_data(self) -> None:
         """Downloads dataset if not present."""
@@ -106,7 +105,7 @@ class EMNIST(BaseDataModule):
 
 
 def emnist_mapping(
-    extra_symbols: Optional[Sequence[str]] = None,
+    extra_symbols: Optional[Set[str]] = None,
 ) -> Tuple[List, Dict[str, int], List[int]]:
     """Return the EMNIST mapping."""
     if not ESSENTIALS_FILENAME.exists():
@@ -130,7 +129,7 @@ def download_and_process_emnist() -> None:
 
 def _process_raw_dataset(filename: str, dirname: Path) -> None:
     """Processes the raw EMNIST dataset."""
-    logger.info("Unzipping EMNIST...")
+    log.info("Unzipping EMNIST...")
     curdir = os.getcwd()
     os.chdir(dirname)
     content = zipfile.ZipFile(filename, "r")
@@ -138,7 +137,7 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
 
     from scipy.io import loadmat
 
-    logger.info("Loading training data from .mat file")
+    log.info("Loading training data from .mat file")
     data = loadmat("matlab/emnist-byclass.mat")
     x_train = (
         data["dataset"]["train"][0, 0]["images"][0, 0]
@@ -152,11 +151,11 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
     y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
 
     if SAMPLE_TO_BALANCE:
-        logger.info("Balancing classes to reduce amount of data")
+        log.info("Balancing classes to reduce amount of data")
         x_train, y_train = _sample_to_balance(x_train, y_train)
         x_test, y_test = _sample_to_balance(x_test, y_test)
 
-    logger.info("Saving to HDF5 in a compressed format...")
+    log.info("Saving to HDF5 in a compressed format...")
     PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
     with h5py.File(PROCESSED_DATA_FILENAME, "w") as f:
         f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf")
@@ -164,7 +163,7 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
         f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf")
         f.create_dataset("y_test", data=y_test, dtype="u1", compression="lzf")
 
-    logger.info("Saving essential dataset parameters to text_recognizer/datasets...")
+    log.info("Saving essential dataset parameters to text_recognizer/datasets...")
     mapping = {int(k): chr(v) for k, v in data["dataset"]["mapping"][0, 0]}
     characters = _augment_emnist_characters(mapping.values())
     essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])}
@@ -172,7 +171,7 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
 
     with ESSENTIALS_FILENAME.open(mode="w") as f:
         json.dump(essentials, f)
 
-    logger.info("Cleaning up...")
+    log.info("Cleaning up...")
     shutil.rmtree("matlab")
     os.chdir(curdir)
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 7548ad5..5298726 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -1,11 +1,11 @@
 """Dataset of generated text from EMNIST characters."""
 from collections import defaultdict
 from pathlib import Path
-from typing import Callable, Dict, Tuple
+from typing import Callable, List, Tuple
 
 import attr
 import h5py
-from loguru import logger
+from loguru import logger as log
 import numpy as np
 import torch
 from torchvision import transforms
@@ -46,8 +46,7 @@ class EMNISTLines(BaseDataModule):
     emnist: EMNIST = attr.ib(init=False, default=None)
 
     def __attrs_post_init__(self) -> None:
-        self.emnist = EMNIST()
-        self.mapping = self.emnist.mapping
+        self.emnist = EMNIST(mapping=self.mapping)
 
         max_width = (
             int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap))
@@ -86,7 +85,7 @@ class EMNISTLines(BaseDataModule):
         self._generate_data("test")
 
     def setup(self, stage: str = None) -> None:
-        logger.info("EMNISTLinesDataset loading data from HDF5...")
+        log.info("EMNISTLinesDataset loading data from HDF5...")
         if stage == "fit" or stage is None:
             print(self.data_filename)
             with h5py.File(self.data_filename, "r") as f:
@@ -137,7 +136,7 @@ class EMNISTLines(BaseDataModule):
         return basic + data
 
     def _generate_data(self, split: str) -> None:
-        logger.info(f"EMNISTLines generating data for {split}...")
+        log.info(f"EMNISTLines generating data for {split}...")
         sentence_generator = SentenceGenerator(
             self.max_length - 2
         )  # Subtract by 2 because start/end token
@@ -148,17 +147,17 @@ class EMNISTLines(BaseDataModule):
 
         if split == "train":
             samples_by_char = _get_samples_by_char(
-                emnist.x_train, emnist.y_train, emnist.mapping
+                emnist.x_train, emnist.y_train, self.mapping.mapping
             )
             num = self.num_train
         elif split == "val":
             samples_by_char = _get_samples_by_char(
-                emnist.x_train, emnist.y_train, emnist.mapping
+                emnist.x_train, emnist.y_train, self.mapping.mapping
             )
             num = self.num_val
         else:
             samples_by_char = _get_samples_by_char(
-                emnist.x_test, emnist.y_test, emnist.mapping
+                emnist.x_test, emnist.y_test, self.mapping.mapping
             )
             num = self.num_test
 
@@ -173,14 +172,14 @@ class EMNISTLines(BaseDataModule):
             self.dims,
         )
         y = convert_strings_to_labels(
-            y, emnist.inverse_mapping, length=MAX_OUTPUT_LENGTH
+            y, self.mapping.inverse_mapping, length=MAX_OUTPUT_LENGTH
         )
         f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf")
         f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf")
 
 
 def _get_samples_by_char(
-    samples: np.ndarray, labels: np.ndarray, mapping: Dict
+    samples: np.ndarray, labels: np.ndarray, mapping: List
 ) -> defaultdict:
     samples_by_char = defaultdict(list)
     for sample, label in zip(samples, labels):
diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py
index 3982c4f..7278eb2 100644
--- a/text_recognizer/data/iam.py
+++ b/text_recognizer/data/iam.py
@@ -7,7 +7,7 @@ import zipfile
 
 import attr
 from boltons.cacheutils import cachedproperty
-from loguru import logger
+from loguru import logger as log
 import toml
 
 from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
@@ -92,7 +92,7 @@ class IAM(BaseDataModule):
 
 
 def _extract_raw_dataset(filename: Path, dirname: Path) -> None:
-    logger.info("Extracting IAM data...")
+    log.info("Extracting IAM data...")
     curdir = os.getcwd()
     os.chdir(dirname)
     with zipfile.ZipFile(filename, "r") as f:
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 0e97801..ccf0759 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -4,7 +4,6 @@ from typing import Dict, List
 import attr
 from torch.utils.data import ConcatDataset
 
-from text_recognizer.data.base_dataset import BaseDataset
 from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
 from text_recognizer.data.iam_paragraphs import IAMParagraphs
 from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs
@@ -20,6 +19,7 @@ class IAMExtendedParagraphs(BaseDataModule):
 
     def __attrs_post_init__(self) -> None:
         self.iam_paragraphs = IAMParagraphs(
+            mapping=self.mapping,
             batch_size=self.batch_size,
             num_workers=self.num_workers,
             train_fraction=self.train_fraction,
@@ -27,6 +27,7 @@ class IAMExtendedParagraphs(BaseDataModule):
             word_pieces=self.word_pieces,
         )
         self.iam_synthetic_paragraphs = IAMSyntheticParagraphs(
+            mapping=self.mapping,
             batch_size=self.batch_size,
             num_workers=self.num_workers,
             train_fraction=self.train_fraction,
@@ -36,7 +37,6 @@ class IAMExtendedParagraphs(BaseDataModule):
 
         self.dims = self.iam_paragraphs.dims
         self.output_dims = self.iam_paragraphs.output_dims
-        self.num_classes = self.iam_paragraphs.num_classes
 
     def prepare_data(self) -> None:
         """Prepares the paragraphs data."""
@@ -58,7 +58,7 @@ class IAMExtendedParagraphs(BaseDataModule):
         """Returns info about the dataset."""
         basic = (
             "IAM Original and Synthetic Paragraphs Dataset\n"  # pylint: disable=no-member
-            f"Num classes: {len(self.num_classes)}\n"
+            f"Num classes: {len(self.mapping)}\n"
             f"Dims: {self.dims}\n"
             f"Output dims: {self.output_dims}\n"
         )
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index b7f3fdd..1c63729 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -2,15 +2,14 @@
 
 If not created, will generate a handwritten lines dataset from the IAM
 paragraphs dataset.
- """ import json from pathlib import Path import random -from typing import Dict, List, Sequence, Tuple +from typing import List, Sequence, Tuple import attr -from loguru import logger +from loguru import logger as log from PIL import Image, ImageFile, ImageOps import numpy as np from torch import Tensor @@ -23,7 +22,7 @@ from text_recognizer.data.base_dataset import ( split_dataset, ) from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info -from text_recognizer.data.emnist import emnist_mapping +from text_recognizer.data.mappings import EmnistMapping from text_recognizer.data.iam import IAM from text_recognizer.data import image_utils @@ -48,17 +47,13 @@ class IAMLines(BaseDataModule): ) output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1)) - def __attrs_post_init__(self) -> None: - # TODO: refactor this - self.mapping, self.inverse_mapping, _ = emnist_mapping() - def prepare_data(self) -> None: """Creates the IAM lines dataset if not existing.""" if PROCESSED_DATA_DIRNAME.exists(): return - logger.info("Cropping IAM lines regions...") - iam = IAM() + log.info("Cropping IAM lines regions...") + iam = IAM(mapping=EmnistMapping()) iam.prepare_data() crops_train, labels_train = line_crops_and_labels(iam, "train") crops_test, labels_test = line_crops_and_labels(iam, "test") @@ -66,7 +61,7 @@ class IAMLines(BaseDataModule): shapes = np.array([crop.size for crop in crops_train + crops_test]) aspect_ratios = shapes[:, 0] / shapes[:, 1] - logger.info("Saving images, labels, and statistics...") + log.info("Saving images, labels, and statistics...") save_images_and_labels( crops_train, labels_train, "train", PROCESSED_DATA_DIRNAME ) @@ -91,7 +86,7 @@ class IAMLines(BaseDataModule): raise ValueError("Target length longer than max output length.") y_train = convert_strings_to_labels( - labels_train, self.inverse_mapping, length=self.output_dims[0] + labels_train, self.mapping.inverse_mapping, length=self.output_dims[0] ) data_train = BaseDataset( x_train, y_train, transform=get_transform(IMAGE_WIDTH, self.augment) @@ -110,7 +105,7 @@ class IAMLines(BaseDataModule): raise ValueError("Taget length longer than max output length.") y_test = convert_strings_to_labels( - labels_test, self.inverse_mapping, length=self.output_dims[0] + labels_test, self.mapping.inverse_mapping, length=self.output_dims[0] ) self.data_test = BaseDataset( x_test, y_test, transform=get_transform(IMAGE_WIDTH) diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index 0f3a2ce..6189f7d 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import Dict, List, Optional, Sequence, Tuple import attr -from loguru import logger +from loguru import logger as log import numpy as np from PIL import Image, ImageOps import torchvision.transforms as T @@ -17,9 +17,8 @@ from text_recognizer.data.base_dataset import ( split_dataset, ) from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info -from text_recognizer.data.emnist import emnist_mapping +from text_recognizer.data.mappings import EmnistMapping from text_recognizer.data.iam import IAM -from text_recognizer.data.mappings import WordPieceMapping from text_recognizer.data.transforms import WordPiece @@ -38,7 +37,6 @@ MAX_LABEL_LENGTH = 682 class IAMParagraphs(BaseDataModule): """IAM handwriting database paragraphs.""" - num_classes: int = attr.ib() word_pieces: bool = 
     augment: bool = attr.ib(default=True)
     train_fraction: float = attr.ib(default=0.8)
@@ -46,21 +44,17 @@ class IAMParagraphs(BaseDataModule):
         init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH)
     )
     output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1))
-    inverse_mapping: Dict[str, int] = attr.ib(init=False)
-
-    def __attrs_post_init__(self) -> None:
-        _, self.inverse_mapping, _ = emnist_mapping(extra_symbols=[NEW_LINE_TOKEN])
 
     def prepare_data(self) -> None:
         """Create data for training/testing."""
         if PROCESSED_DATA_DIRNAME.exists():
             return
 
-        logger.info(
+        log.info(
             "Cropping IAM paragraph regions and saving them along with labels..."
         )
-        iam = IAM()
+        iam = IAM(mapping=EmnistMapping())
         iam.prepare_data()
 
         properties = {}
@@ -89,7 +83,7 @@ class IAMParagraphs(BaseDataModule):
         crops, labels = _load_processed_crops_and_labels(split)
         data = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops]
         targets = convert_strings_to_labels(
-            strings=labels, mapping=self.inverse_mapping, length=self.output_dims[0]
+            strings=labels, mapping=self.mapping.inverse_mapping, length=self.output_dims[0]
         )
         return BaseDataset(
             data,
@@ -98,7 +92,7 @@ class IAMParagraphs(BaseDataModule):
             targets,
             target_transform=get_target_transform(self.word_pieces),
         )
 
-        logger.info(f"Loading IAM paragraph regions and lines for {stage}...")
+        log.info(f"Loading IAM paragraph regions and lines for {stage}...")
         _validate_data_dims(input_dims=self.dims, output_dims=self.output_dims)
         if stage == "fit" or stage is None:
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py
index 93a13bb..bcd77b4 100644
--- a/text_recognizer/data/iam_preprocessor.py
+++ b/text_recognizer/data/iam_preprocessor.py
@@ -1,18 +1,16 @@
 """Preprocessor for extracting word letters from the IAM dataset.
 
 The code is mostly stolen from:
-    https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py
-
 """
 import collections
 import itertools
 from pathlib import Path
 import re
-from typing import List, Optional, Union, Sequence
+from typing import List, Optional, Union, Set
 
 import click
-from loguru import logger
+from loguru import logger as log
 import torch
@@ -57,7 +55,7 @@ class Preprocessor:
         lexicon_path: Optional[Union[str, Path]] = None,
         use_words: bool = False,
         prepend_wordsep: bool = False,
-        special_tokens: Optional[Sequence[str]] = None,
+        special_tokens: Optional[Set[str]] = None,
     ) -> None:
         self.wordsep = "▁"
         self._use_word = use_words
@@ -186,7 +184,7 @@ def cli(
             / "iam"
             / "iamdb"
         )
-        logger.debug(f"Using data dir: {data_dir}")
+        log.debug(f"Using data dir: {data_dir}")
         if not data_dir.exists():
             raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
     else:
@@ -196,15 +194,15 @@ def cli(
     preprocessor.extract_train_text()
 
     processed_dir = data_dir.parents[2] / "processed" / "iam_lines"
-    logger.debug(f"Saving processed files at: {processed_dir}")
+    log.debug(f"Saving processed files at: {processed_dir}")
 
     if save_text is not None:
-        logger.info("Saving training text")
+        log.info("Saving training text")
         with open(processed_dir / save_text, "w") as f:
             f.write("\n".join(t for t in preprocessor.text))
 
     if save_tokens is not None:
-        logger.info("Saving tokens")
+        log.info("Saving tokens")
         with open(processed_dir / save_tokens, "w") as f:
             f.write("\n".join(preprocessor.tokens))
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index f00a494..c938f8b 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -3,7 +3,7 @@ import random
 from typing import Any, List, Sequence, Tuple
 
 import attr
-from loguru import logger
+from loguru import logger as log
 import numpy as np
 from PIL import Image
 
@@ -21,6 +21,7 @@ from text_recognizer.data.iam_paragraphs import (
     IMAGE_SCALE_FACTOR,
     resize_image,
 )
+from text_recognizer.data.mappings import EmnistMapping
 from text_recognizer.data.iam import IAM
 from text_recognizer.data.iam_lines import (
     line_crops_and_labels,
@@ -43,10 +44,10 @@ class IAMSyntheticParagraphs(IAMParagraphs):
         if PROCESSED_DATA_DIRNAME.exists():
             return
 
-        logger.info("Preparing IAM lines for synthetic paragraphs dataset.")
-        logger.info("Cropping IAM line regions and loading labels.")
+        log.info("Preparing IAM lines for synthetic paragraphs dataset.")
+        log.info("Cropping IAM line regions and loading labels.")
 
-        iam = IAM()
+        iam = IAM(mapping=EmnistMapping())
         iam.prepare_data()
 
         crops_train, labels_train = line_crops_and_labels(iam, "train")
@@ -55,7 +56,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
         crops_train = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops_train]
         crops_test = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops_test]
 
-        logger.info(f"Saving images and labels at {PROCESSED_DATA_DIRNAME}")
+        log.info(f"Saving images and labels at {PROCESSED_DATA_DIRNAME}")
         save_images_and_labels(
             crops_train, labels_train, "train", PROCESSED_DATA_DIRNAME
         )
@@ -64,7 +65,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
 
     def setup(self, stage: str = None) -> None:
         """Loading synthetic dataset."""
-        logger.info(f"IAM Synthetic dataset steup for stage {stage}...")
+        log.info(f"IAM Synthetic dataset setup for stage {stage}...")
 
         if stage == "fit" or stage is None:
             line_crops, line_labels = load_line_crops_and_labels(
@@ -76,7 +77,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
 
             targets = convert_strings_to_labels(
                 strings=paragraphs_labels,
-                mapping=self.inverse_mapping,
+                mapping=self.mapping.inverse_mapping,
                 length=self.output_dims[0],
             )
             self.data_train = BaseDataset(
@@ -144,7 +145,7 @@ def generate_synthetic_paragraphs(
             [line_labels[i] for i in paragraph_indices]
         )
         if len(paragraph_label) > paragraphs_properties["label_length"]["max"]:
-            logger.info(
+            log.info(
                 "Label longer than longest label in original IAM paragraph dataset - hence dropping."
             )
             continue
@@ -158,7 +159,7 @@ def generate_synthetic_paragraphs(
             paragraph_crop.height > max_paragraph_shape[0]
             or paragraph_crop.width > max_paragraph_shape[1]
         ):
-            logger.info(
+            log.info(
                 "Crop larger than largest crop in original IAM paragraphs dataset - hence dropping"
             )
             continue
diff --git a/text_recognizer/data/make_wordpieces.py b/text_recognizer/data/make_wordpieces.py
index ef9eb1b..40fbee4 100644
--- a/text_recognizer/data/make_wordpieces.py
+++ b/text_recognizer/data/make_wordpieces.py
@@ -10,7 +10,7 @@ from pathlib import Path
 from typing import List, Optional, Union
 
 import click
-from loguru import logger
+from loguru import logger as log
 import sentencepiece as spm
 
 from text_recognizer.data.iam_preprocessor import load_metadata
@@ -63,9 +63,9 @@ def save_pieces(
     vocab: set,
 ) -> None:
     """Saves word pieces to disk."""
-    logger.info(f"Generating word piece list of size {num_pieces}.")
+    log.info(f"Generating word piece list of size {num_pieces}.")
     pieces = [sp.id_to_piece(i) for i in range(1, num_pieces + 1)]
-    logger.info(f"Encoding vocabulary of size {len(vocab)}.")
+    log.info(f"Encoding vocabulary of size {len(vocab)}.")
     encoded_vocab = [sp.encode_as_pieces(v) for v in vocab]
 
     # Save pieces to file.
@@ -101,7 +101,7 @@ def cli(
         data_dir = (
             Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines"
         )
-        logger.debug(f"Using data dir: {data_dir}")
+        log.debug(f"Using data dir: {data_dir}")
         if not data_dir.exists():
             raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
     else:
diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py
index b69e888..d1c64dd 100644
--- a/text_recognizer/data/mappings.py
+++ b/text_recognizer/data/mappings.py
@@ -1,18 +1,30 @@
 """Mapping to and from word pieces."""
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Dict, List, Optional, Union, Set, Sequence
+from typing import Dict, List, Optional, Union, Set
 
 import attr
-import loguru.logger as log
 import torch
+from loguru import logger as log
 from torch import Tensor
 
 from text_recognizer.data.emnist import emnist_mapping
 from text_recognizer.data.iam_preprocessor import Preprocessor
 
 
+@attr.s
 class AbstractMapping(ABC):
+    input_size: List[int] = attr.ib(init=False)
+    mapping: List[str] = attr.ib(init=False)
+    inverse_mapping: Dict[str, int] = attr.ib(init=False)
+
+    def __len__(self) -> int:
+        return len(self.mapping)
+
+    @property
+    def num_classes(self) -> int:
+        return self.__len__()
+
     @abstractmethod
     def get_token(self, *args, **kwargs) -> str:
         ...
@@ -30,15 +42,13 @@ class AbstractMapping(ABC):
         ...
 
 
-@attr.s
+@attr.s(auto_attribs=True)
 class EmnistMapping(AbstractMapping):
-    extra_symbols: Optional[Set[str]] = attr.ib(default=None, converter=set)
-    mapping: Sequence[str] = attr.ib(init=False)
-    inverse_mapping: Dict[str, int] = attr.ib(init=False)
-    input_size: List[int] = attr.ib(init=False)
+    extra_symbols: Optional[Set[str]] = attr.ib(default=None)
 
     def __attrs_post_init__(self) -> None:
         """Post init configuration."""
+        self.extra_symbols = set(self.extra_symbols) if self.extra_symbols is not None else None
         self.mapping, self.inverse_mapping, self.input_size = emnist_mapping(
             self.extra_symbols
         )
-- 
cgit v1.2.3-70-g09d2
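
After this commit, a datamodule no longer builds its own mapping in __attrs_post_init__; the caller constructs one and injects it through the new mapping attribute on BaseDataModule. A minimal sketch of the new wiring, using only names that appear in the diff (the batch size, worker count, and other values are illustrative, and "\n" stands in for the NEW_LINE_TOKEN constant referenced by the removed IAMParagraphs post-init hook):

    from text_recognizer.data.iam_paragraphs import IAMParagraphs
    from text_recognizer.data.mappings import EmnistMapping

    # Build the mapping once and inject it; the datamodule no longer
    # constructs one itself.
    mapping = EmnistMapping(extra_symbols={"\n"})

    datamodule = IAMParagraphs(
        mapping=mapping,
        batch_size=32,
        num_workers=4,
        train_fraction=0.8,
        augment=True,
        word_pieces=False,
    )
    datamodule.prepare_data()
    datamodule.setup("fit")

    # The class count now lives on the mapping, not the datamodule.
    print(len(mapping), mapping.num_classes)

This is also why the num_classes attribute and the "mapping" entry in config() disappear above: both are now answered by the shared mapping object.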
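The subject line also mentions making nn modules hashable; that change lives outside text_recognizer/data and is not part of this diff. The underlying Python behavior: plain @attr.s generates __eq__, which makes Python set __hash__ to None, so an attrs-decorated nn.Module can no longer be stored in the sets and dicts PyTorch uses internally. The sketch below shows the usual remedy (eq=False); whether the commit uses exactly this flag is an assumption, and the class itself is hypothetical:

    import attr
    import torch.nn as nn

    @attr.s(eq=False)  # eq=False keeps identity-based __eq__/__hash__, so instances stay hashable
    class ConvBlock(nn.Module):
        """Hypothetical attrs-decorated module, for illustration only."""

        channels: int = attr.ib(default=8)

        def __attrs_pre_init__(self) -> None:
            super().__init__()  # nn.Module must be initialized before attrs assigns fields

        def __attrs_post_init__(self) -> None:
            self.conv = nn.Conv2d(self.channels, self.channels, kernel_size=3, padding=1)

    block = ConvBlock()
    print(hash(block))  # works; with plain @attr.s, __hash__ is None and this raises TypeError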