diff options
Diffstat (limited to 'text_recognizer')
22 files changed, 180 insertions, 205 deletions
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py index 408ae36..fd914b6 100644 --- a/text_recognizer/data/base_data_module.py +++ b/text_recognizer/data/base_data_module.py @@ -1,11 +1,12 @@ """Base lightning DataModule class.""" from pathlib import Path -from typing import Any, Dict, Tuple +from typing import Dict, Tuple import attr from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader +from text_recognizer.data.mappings import AbstractMapping from text_recognizer.data.base_dataset import BaseDataset @@ -24,8 +25,10 @@ class BaseDataModule(LightningDataModule): def __attrs_pre_init__(self) -> None: super().__init__() + mapping: AbstractMapping = attr.ib() batch_size: int = attr.ib(default=16) num_workers: int = attr.ib(default=0) + pin_memory: bool = attr.ib(default=True) # Placeholders data_train: BaseDataset = attr.ib(init=False, default=None) @@ -33,8 +36,6 @@ class BaseDataModule(LightningDataModule): data_test: BaseDataset = attr.ib(init=False, default=None) dims: Tuple[int, ...] = attr.ib(init=False, default=None) output_dims: Tuple[int, ...] = attr.ib(init=False, default=None) - mapping: Any = attr.ib(init=False, default=None) - inverse_mapping: Dict[str, int] = attr.ib(init=False) @classmethod def data_dirname(cls) -> Path: @@ -46,7 +47,6 @@ class BaseDataModule(LightningDataModule): return { "input_dim": self.dims, "output_dims": self.output_dims, - "mapping": self.mapping, } def prepare_data(self) -> None: @@ -72,7 +72,7 @@ class BaseDataModule(LightningDataModule): shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, - pin_memory=True, + pin_memory=self.pin_memory, ) def val_dataloader(self) -> DataLoader: @@ -82,7 +82,7 @@ class BaseDataModule(LightningDataModule): shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, - pin_memory=True, + pin_memory=self.pin_memory, ) def test_dataloader(self) -> DataLoader: @@ -92,5 +92,5 @@ class BaseDataModule(LightningDataModule): shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, - pin_memory=True, + pin_memory=self.pin_memory, ) diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py index c26f1c9..8640d92 100644 --- a/text_recognizer/data/base_dataset.py +++ b/text_recognizer/data/base_dataset.py @@ -1,5 +1,5 @@ """Base PyTorch Dataset class.""" -from typing import Any, Callable, Dict, Sequence, Tuple, Union +from typing import Callable, Dict, Optional, Sequence, Tuple, Union import attr import torch @@ -22,14 +22,13 @@ class BaseDataset(Dataset): data: Union[Sequence, Tensor] = attr.ib() targets: Union[Sequence, Tensor] = attr.ib() - transform: Callable = attr.ib() - target_transform: Callable = attr.ib() + transform: Optional[Callable] = attr.ib(default=None) + target_transform: Optional[Callable] = attr.ib(default=None) def __attrs_pre_init__(self) -> None: super().__init__() def __attrs_post_init__(self) -> None: - # TODO: refactor this if len(self.data) != len(self.targets): raise ValueError("Data and targets must be of equal length.") @@ -37,14 +36,14 @@ class BaseDataset(Dataset): """Return the length of the dataset.""" return len(self.data) - def __getitem__(self, index: int) -> Tuple[Any, Any]: + def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: """Return a datum and its target, after processing by transforms. Args: index (int): Index of a datum in the dataset. Returns: - Tuple[Any, Any]: Datum and target pair. + Tuple[Tensor, Tensor]: Datum and target pair. """ datum, target = self.data[index], self.targets[index] diff --git a/text_recognizer/data/download_utils.py b/text_recognizer/data/download_utils.py index e3dc68c..8938830 100644 --- a/text_recognizer/data/download_utils.py +++ b/text_recognizer/data/download_utils.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import Dict, List, Optional from urllib.request import urlretrieve -from loguru import logger +from loguru import logger as log from tqdm import tqdm @@ -32,7 +32,7 @@ class TqdmUpTo(tqdm): total_size (Optional[int]): Total size in tqdm units. Defaults to None. """ if total_size is not None: - self.total = total_size # pylint: disable=attribute-defined-outside-init + self.total = total_size self.update(blocks * block_size - self.n) @@ -62,9 +62,9 @@ def download_dataset(metadata: Dict, dl_dir: Path) -> Optional[Path]: filename = dl_dir / metadata["filename"] if filename.exists(): return - logger.info(f"Downloading raw dataset from {metadata['url']} to {filename}...") + log.info(f"Downloading raw dataset from {metadata['url']} to {filename}...") _download_url(metadata["url"], filename) - logger.info("Computing the SHA-256...") + log.info("Computing the SHA-256...") sha256 = _compute_sha256(filename) if sha256 != metadata["sha256"]: raise ValueError( diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py index 2d0ac29..c6be123 100644 --- a/text_recognizer/data/emnist.py +++ b/text_recognizer/data/emnist.py @@ -3,12 +3,12 @@ import json import os from pathlib import Path import shutil -from typing import Callable, Dict, List, Optional, Sequence, Tuple +from typing import Callable, Dict, List, Optional, Set, Sequence, Tuple import zipfile import attr import h5py -from loguru import logger +from loguru import logger as log import numpy as np import toml import torchvision.transforms as T @@ -50,8 +50,7 @@ class EMNIST(BaseDataModule): transform: Callable = attr.ib(init=False, default=T.Compose([T.ToTensor()])) def __attrs_post_init__(self) -> None: - self.mapping, self.inverse_mapping, input_shape = emnist_mapping() - self.dims = (1, *input_shape) + self.dims = (1, *self.mapping.input_size) def prepare_data(self) -> None: """Downloads dataset if not present.""" @@ -106,7 +105,7 @@ class EMNIST(BaseDataModule): def emnist_mapping( - extra_symbols: Optional[Sequence[str]] = None, + extra_symbols: Optional[Set[str]] = None, ) -> Tuple[List, Dict[str, int], List[int]]: """Return the EMNIST mapping.""" if not ESSENTIALS_FILENAME.exists(): @@ -130,7 +129,7 @@ def download_and_process_emnist() -> None: def _process_raw_dataset(filename: str, dirname: Path) -> None: """Processes the raw EMNIST dataset.""" - logger.info("Unzipping EMNIST...") + log.info("Unzipping EMNIST...") curdir = os.getcwd() os.chdir(dirname) content = zipfile.ZipFile(filename, "r") @@ -138,7 +137,7 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None: from scipy.io import loadmat - logger.info("Loading training data from .mat file") + log.info("Loading training data from .mat file") data = loadmat("matlab/emnist-byclass.mat") x_train = ( data["dataset"]["train"][0, 0]["images"][0, 0] @@ -152,11 +151,11 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None: y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS if SAMPLE_TO_BALANCE: - logger.info("Balancing classes to reduce amount of data") + log.info("Balancing classes to reduce amount of data") x_train, y_train = _sample_to_balance(x_train, y_train) x_test, y_test = _sample_to_balance(x_test, y_test) - logger.info("Saving to HDF5 in a compressed format...") + log.info("Saving to HDF5 in a compressed format...") PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) with h5py.File(PROCESSED_DATA_FILENAME, "w") as f: f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf") @@ -164,7 +163,7 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None: f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf") f.create_dataset("y_test", data=y_test, dtype="u1", compression="lzf") - logger.info("Saving essential dataset parameters to text_recognizer/datasets...") + log.info("Saving essential dataset parameters to text_recognizer/datasets...") mapping = {int(k): chr(v) for k, v in data["dataset"]["mapping"][0, 0]} characters = _augment_emnist_characters(mapping.values()) essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])} @@ -172,7 +171,7 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None: with ESSENTIALS_FILENAME.open(mode="w") as f: json.dump(essentials, f) - logger.info("Cleaning up...") + log.info("Cleaning up...") shutil.rmtree("matlab") os.chdir(curdir) diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py index 7548ad5..5298726 100644 --- a/text_recognizer/data/emnist_lines.py +++ b/text_recognizer/data/emnist_lines.py @@ -1,11 +1,11 @@ """Dataset of generated text from EMNIST characters.""" from collections import defaultdict from pathlib import Path -from typing import Callable, Dict, Tuple +from typing import Callable, List, Tuple import attr import h5py -from loguru import logger +from loguru import logger as log import numpy as np import torch from torchvision import transforms @@ -46,8 +46,7 @@ class EMNISTLines(BaseDataModule): emnist: EMNIST = attr.ib(init=False, default=None) def __attrs_post_init__(self) -> None: - self.emnist = EMNIST() - self.mapping = self.emnist.mapping + self.emnist = EMNIST(mapping=self.mapping) max_width = ( int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap)) @@ -86,7 +85,7 @@ class EMNISTLines(BaseDataModule): self._generate_data("test") def setup(self, stage: str = None) -> None: - logger.info("EMNISTLinesDataset loading data from HDF5...") + log.info("EMNISTLinesDataset loading data from HDF5...") if stage == "fit" or stage is None: print(self.data_filename) with h5py.File(self.data_filename, "r") as f: @@ -137,7 +136,7 @@ class EMNISTLines(BaseDataModule): return basic + data def _generate_data(self, split: str) -> None: - logger.info(f"EMNISTLines generating data for {split}...") + log.info(f"EMNISTLines generating data for {split}...") sentence_generator = SentenceGenerator( self.max_length - 2 ) # Subtract by 2 because start/end token @@ -148,17 +147,17 @@ class EMNISTLines(BaseDataModule): if split == "train": samples_by_char = _get_samples_by_char( - emnist.x_train, emnist.y_train, emnist.mapping + emnist.x_train, emnist.y_train, self.mapping.mapping ) num = self.num_train elif split == "val": samples_by_char = _get_samples_by_char( - emnist.x_train, emnist.y_train, emnist.mapping + emnist.x_train, emnist.y_train, self.mapping.mapping ) num = self.num_val else: samples_by_char = _get_samples_by_char( - emnist.x_test, emnist.y_test, emnist.mapping + emnist.x_test, emnist.y_test, self.mapping.mapping ) num = self.num_test @@ -173,14 +172,14 @@ class EMNISTLines(BaseDataModule): self.dims, ) y = convert_strings_to_labels( - y, emnist.inverse_mapping, length=MAX_OUTPUT_LENGTH + y, self.mapping.inverse_mapping, length=MAX_OUTPUT_LENGTH ) f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf") f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf") def _get_samples_by_char( - samples: np.ndarray, labels: np.ndarray, mapping: Dict + samples: np.ndarray, labels: np.ndarray, mapping: List ) -> defaultdict: samples_by_char = defaultdict(list) for sample, label in zip(samples, labels): diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py index 3982c4f..7278eb2 100644 --- a/text_recognizer/data/iam.py +++ b/text_recognizer/data/iam.py @@ -7,7 +7,7 @@ import zipfile import attr from boltons.cacheutils import cachedproperty -from loguru import logger +from loguru import logger as log import toml from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info @@ -92,7 +92,7 @@ class IAM(BaseDataModule): def _extract_raw_dataset(filename: Path, dirname: Path) -> None: - logger.info("Extracting IAM data...") + log.info("Extracting IAM data...") curdir = os.getcwd() os.chdir(dirname) with zipfile.ZipFile(filename, "r") as f: diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py index 0e97801..ccf0759 100644 --- a/text_recognizer/data/iam_extended_paragraphs.py +++ b/text_recognizer/data/iam_extended_paragraphs.py @@ -4,7 +4,6 @@ from typing import Dict, List import attr from torch.utils.data import ConcatDataset -from text_recognizer.data.base_dataset import BaseDataset from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.iam_paragraphs import IAMParagraphs from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs @@ -20,6 +19,7 @@ class IAMExtendedParagraphs(BaseDataModule): def __attrs_post_init__(self) -> None: self.iam_paragraphs = IAMParagraphs( + mapping=self.mapping, batch_size=self.batch_size, num_workers=self.num_workers, train_fraction=self.train_fraction, @@ -27,6 +27,7 @@ class IAMExtendedParagraphs(BaseDataModule): word_pieces=self.word_pieces, ) self.iam_synthetic_paragraphs = IAMSyntheticParagraphs( + mapping=self.mapping, batch_size=self.batch_size, num_workers=self.num_workers, train_fraction=self.train_fraction, @@ -36,7 +37,6 @@ class IAMExtendedParagraphs(BaseDataModule): self.dims = self.iam_paragraphs.dims self.output_dims = self.iam_paragraphs.output_dims - self.num_classes = self.iam_paragraphs.num_classes def prepare_data(self) -> None: """Prepares the paragraphs data.""" @@ -58,7 +58,7 @@ class IAMExtendedParagraphs(BaseDataModule): """Returns info about the dataset.""" basic = ( "IAM Original and Synthetic Paragraphs Dataset\n" # pylint: disable=no-member - f"Num classes: {len(self.num_classes)}\n" + f"Num classes: {len(self.mapping)}\n" f"Dims: {self.dims}\n" f"Output dims: {self.output_dims}\n" ) diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py index b7f3fdd..1c63729 100644 --- a/text_recognizer/data/iam_lines.py +++ b/text_recognizer/data/iam_lines.py @@ -2,15 +2,14 @@ If not created, will generate a handwritten lines dataset from the IAM paragraphs dataset. - """ import json from pathlib import Path import random -from typing import Dict, List, Sequence, Tuple +from typing import List, Sequence, Tuple import attr -from loguru import logger +from loguru import logger as log from PIL import Image, ImageFile, ImageOps import numpy as np from torch import Tensor @@ -23,7 +22,7 @@ from text_recognizer.data.base_dataset import ( split_dataset, ) from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info -from text_recognizer.data.emnist import emnist_mapping +from text_recognizer.data.mappings import EmnistMapping from text_recognizer.data.iam import IAM from text_recognizer.data import image_utils @@ -48,17 +47,13 @@ class IAMLines(BaseDataModule): ) output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1)) - def __attrs_post_init__(self) -> None: - # TODO: refactor this - self.mapping, self.inverse_mapping, _ = emnist_mapping() - def prepare_data(self) -> None: """Creates the IAM lines dataset if not existing.""" if PROCESSED_DATA_DIRNAME.exists(): return - logger.info("Cropping IAM lines regions...") - iam = IAM() + log.info("Cropping IAM lines regions...") + iam = IAM(mapping=EmnistMapping()) iam.prepare_data() crops_train, labels_train = line_crops_and_labels(iam, "train") crops_test, labels_test = line_crops_and_labels(iam, "test") @@ -66,7 +61,7 @@ class IAMLines(BaseDataModule): shapes = np.array([crop.size for crop in crops_train + crops_test]) aspect_ratios = shapes[:, 0] / shapes[:, 1] - logger.info("Saving images, labels, and statistics...") + log.info("Saving images, labels, and statistics...") save_images_and_labels( crops_train, labels_train, "train", PROCESSED_DATA_DIRNAME ) @@ -91,7 +86,7 @@ class IAMLines(BaseDataModule): raise ValueError("Target length longer than max output length.") y_train = convert_strings_to_labels( - labels_train, self.inverse_mapping, length=self.output_dims[0] + labels_train, self.mapping.inverse_mapping, length=self.output_dims[0] ) data_train = BaseDataset( x_train, y_train, transform=get_transform(IMAGE_WIDTH, self.augment) @@ -110,7 +105,7 @@ class IAMLines(BaseDataModule): raise ValueError("Taget length longer than max output length.") y_test = convert_strings_to_labels( - labels_test, self.inverse_mapping, length=self.output_dims[0] + labels_test, self.mapping.inverse_mapping, length=self.output_dims[0] ) self.data_test = BaseDataset( x_test, y_test, transform=get_transform(IMAGE_WIDTH) diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index 0f3a2ce..6189f7d 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import Dict, List, Optional, Sequence, Tuple import attr -from loguru import logger +from loguru import logger as log import numpy as np from PIL import Image, ImageOps import torchvision.transforms as T @@ -17,9 +17,8 @@ from text_recognizer.data.base_dataset import ( split_dataset, ) from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info -from text_recognizer.data.emnist import emnist_mapping +from text_recognizer.data.mappings import EmnistMapping from text_recognizer.data.iam import IAM -from text_recognizer.data.mappings import WordPieceMapping from text_recognizer.data.transforms import WordPiece @@ -38,7 +37,6 @@ MAX_LABEL_LENGTH = 682 class IAMParagraphs(BaseDataModule): """IAM handwriting database paragraphs.""" - num_classes: int = attr.ib() word_pieces: bool = attr.ib(default=False) augment: bool = attr.ib(default=True) train_fraction: float = attr.ib(default=0.8) @@ -46,21 +44,17 @@ class IAMParagraphs(BaseDataModule): init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH) ) output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1)) - inverse_mapping: Dict[str, int] = attr.ib(init=False) - - def __attrs_post_init__(self) -> None: - _, self.inverse_mapping, _ = emnist_mapping(extra_symbols=[NEW_LINE_TOKEN]) def prepare_data(self) -> None: """Create data for training/testing.""" if PROCESSED_DATA_DIRNAME.exists(): return - logger.info( + log.info( "Cropping IAM paragraph regions and saving them along with labels..." ) - iam = IAM() + iam = IAM(mapping=EmnistMapping()) iam.prepare_data() properties = {} @@ -89,7 +83,7 @@ class IAMParagraphs(BaseDataModule): crops, labels = _load_processed_crops_and_labels(split) data = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops] targets = convert_strings_to_labels( - strings=labels, mapping=self.inverse_mapping, length=self.output_dims[0] + strings=labels, mapping=self.mapping.inverse_mapping, length=self.output_dims[0] ) return BaseDataset( data, @@ -98,7 +92,7 @@ class IAMParagraphs(BaseDataModule): target_transform=get_target_transform(self.word_pieces), ) - logger.info(f"Loading IAM paragraph regions and lines for {stage}...") + log.info(f"Loading IAM paragraph regions and lines for {stage}...") _validate_data_dims(input_dims=self.dims, output_dims=self.output_dims) if stage == "fit" or stage is None: diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py index 93a13bb..bcd77b4 100644 --- a/text_recognizer/data/iam_preprocessor.py +++ b/text_recognizer/data/iam_preprocessor.py @@ -1,18 +1,16 @@ """Preprocessor for extracting word letters from the IAM dataset. The code is mostly stolen from: - https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py - """ import collections import itertools from pathlib import Path import re -from typing import List, Optional, Union, Sequence +from typing import List, Optional, Union, Set import click -from loguru import logger +from loguru import logger as log import torch @@ -57,7 +55,7 @@ class Preprocessor: lexicon_path: Optional[Union[str, Path]] = None, use_words: bool = False, prepend_wordsep: bool = False, - special_tokens: Optional[Sequence[str]] = None, + special_tokens: Optional[Set[str]] = None, ) -> None: self.wordsep = "▁" self._use_word = use_words @@ -186,7 +184,7 @@ def cli( / "iam" / "iamdb" ) - logger.debug(f"Using data dir: {data_dir}") + log.debug(f"Using data dir: {data_dir}") if not data_dir.exists(): raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") else: @@ -196,15 +194,15 @@ def cli( preprocessor.extract_train_text() processed_dir = data_dir.parents[2] / "processed" / "iam_lines" - logger.debug(f"Saving processed files at: {processed_dir}") + log.debug(f"Saving processed files at: {processed_dir}") if save_text is not None: - logger.info("Saving training text") + log.info("Saving training text") with open(processed_dir / save_text, "w") as f: f.write("\n".join(t for t in preprocessor.text)) if save_tokens is not None: - logger.info("Saving tokens") + log.info("Saving tokens") with open(processed_dir / save_tokens, "w") as f: f.write("\n".join(preprocessor.tokens)) diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py index f00a494..c938f8b 100644 --- a/text_recognizer/data/iam_synthetic_paragraphs.py +++ b/text_recognizer/data/iam_synthetic_paragraphs.py @@ -3,7 +3,7 @@ import random from typing import Any, List, Sequence, Tuple import attr -from loguru import logger +from loguru import logger as log import numpy as np from PIL import Image @@ -21,6 +21,7 @@ from text_recognizer.data.iam_paragraphs import ( IMAGE_SCALE_FACTOR, resize_image, ) +from text_recognizer.data.mappings import EmnistMapping from text_recognizer.data.iam import IAM from text_recognizer.data.iam_lines import ( line_crops_and_labels, @@ -43,10 +44,10 @@ class IAMSyntheticParagraphs(IAMParagraphs): if PROCESSED_DATA_DIRNAME.exists(): return - logger.info("Preparing IAM lines for synthetic paragraphs dataset.") - logger.info("Cropping IAM line regions and loading labels.") + log.info("Preparing IAM lines for synthetic paragraphs dataset.") + log.info("Cropping IAM line regions and loading labels.") - iam = IAM() + iam = IAM(mapping=EmnistMapping()) iam.prepare_data() crops_train, labels_train = line_crops_and_labels(iam, "train") @@ -55,7 +56,7 @@ class IAMSyntheticParagraphs(IAMParagraphs): crops_train = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops_train] crops_test = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops_test] - logger.info(f"Saving images and labels at {PROCESSED_DATA_DIRNAME}") + log.info(f"Saving images and labels at {PROCESSED_DATA_DIRNAME}") save_images_and_labels( crops_train, labels_train, "train", PROCESSED_DATA_DIRNAME ) @@ -64,7 +65,7 @@ class IAMSyntheticParagraphs(IAMParagraphs): def setup(self, stage: str = None) -> None: """Loading synthetic dataset.""" - logger.info(f"IAM Synthetic dataset steup for stage {stage}...") + log.info(f"IAM Synthetic dataset steup for stage {stage}...") if stage == "fit" or stage is None: line_crops, line_labels = load_line_crops_and_labels( @@ -76,7 +77,7 @@ class IAMSyntheticParagraphs(IAMParagraphs): targets = convert_strings_to_labels( strings=paragraphs_labels, - mapping=self.inverse_mapping, + mapping=self.mapping.inverse_mapping, length=self.output_dims[0], ) self.data_train = BaseDataset( @@ -144,7 +145,7 @@ def generate_synthetic_paragraphs( [line_labels[i] for i in paragraph_indices] ) if len(paragraph_label) > paragraphs_properties["label_length"]["max"]: - logger.info( + log.info( "Label longer than longest label in original IAM paragraph dataset - hence dropping." ) continue @@ -158,7 +159,7 @@ def generate_synthetic_paragraphs( paragraph_crop.height > max_paragraph_shape[0] or paragraph_crop.width > max_paragraph_shape[1] ): - logger.info( + log.info( "Crop larger than largest crop in original IAM paragraphs dataset - hence dropping" ) continue diff --git a/text_recognizer/data/make_wordpieces.py b/text_recognizer/data/make_wordpieces.py index ef9eb1b..40fbee4 100644 --- a/text_recognizer/data/make_wordpieces.py +++ b/text_recognizer/data/make_wordpieces.py @@ -10,7 +10,7 @@ from pathlib import Path from typing import List, Optional, Union import click -from loguru import logger +from loguru import logger as log import sentencepiece as spm from text_recognizer.data.iam_preprocessor import load_metadata @@ -63,9 +63,9 @@ def save_pieces( vocab: set, ) -> None: """Saves word pieces to disk.""" - logger.info(f"Generating word piece list of size {num_pieces}.") + log.info(f"Generating word piece list of size {num_pieces}.") pieces = [sp.id_to_piece(i) for i in range(1, num_pieces + 1)] - logger.info(f"Encoding vocabulary of size {len(vocab)}.") + log.info(f"Encoding vocabulary of size {len(vocab)}.") encoded_vocab = [sp.encode_as_pieces(v) for v in vocab] # Save pieces to file. @@ -101,7 +101,7 @@ def cli( data_dir = ( Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines" ) - logger.debug(f"Using data dir: {data_dir}") + log.debug(f"Using data dir: {data_dir}") if not data_dir.exists(): raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") else: diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py index b69e888..d1c64dd 100644 --- a/text_recognizer/data/mappings.py +++ b/text_recognizer/data/mappings.py @@ -1,18 +1,30 @@ """Mapping to and from word pieces.""" from abc import ABC, abstractmethod from pathlib import Path -from typing import Dict, List, Optional, Union, Set, Sequence +from typing import Dict, List, Optional, Union, Set import attr -import loguru.logger as log import torch +from loguru import logger as log from torch import Tensor from text_recognizer.data.emnist import emnist_mapping from text_recognizer.data.iam_preprocessor import Preprocessor +@attr.s class AbstractMapping(ABC): + input_size: List[int] = attr.ib(init=False) + mapping: List[str] = attr.ib(init=False) + inverse_mapping: Dict[str, int] = attr.ib(init=False) + + def __len__(self) -> int: + return len(self.mapping) + + @property + def num_classes(self) -> int: + return self.__len__() + @abstractmethod def get_token(self, *args, **kwargs) -> str: ... @@ -30,15 +42,13 @@ class AbstractMapping(ABC): ... -@attr.s +@attr.s(auto_attribs=True) class EmnistMapping(AbstractMapping): - extra_symbols: Optional[Set[str]] = attr.ib(default=None, converter=set) - mapping: Sequence[str] = attr.ib(init=False) - inverse_mapping: Dict[str, int] = attr.ib(init=False) - input_size: List[int] = attr.ib(init=False) + extra_symbols: Optional[Set[str]] = attr.ib(default=None) def __attrs_post_init__(self) -> None: """Post init configuration.""" + self.extra_symbols = set(self.extra_symbols) if self.extra_symbols is not None else None self.mapping, self.inverse_mapping, self.input_size = emnist_mapping( self.extra_symbols ) diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index caf63c1..8ce5c37 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -12,7 +12,7 @@ from torch import Tensor import torchmetrics -@attr.s +@attr.s(eq=False) class BaseLitModel(LightningModule): """Abstract PyTorch Lightning class.""" diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py index 0eb42dc..f83c9e4 100644 --- a/text_recognizer/models/metrics.py +++ b/text_recognizer/models/metrics.py @@ -8,11 +8,11 @@ from torch import Tensor from torchmetrics import Metric -@attr.s +@attr.s(eq=False) class CharacterErrorRate(Metric): """Character error rate metric, computed using Levenshtein distance.""" - ignore_indices: Set = attr.ib(converter=set) + ignore_indices: Set[Tensor] = attr.ib(converter=set) error: Tensor = attr.ib(init=False) total: Tensor = attr.ib(init=False) diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 0e01bb5..91e088d 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -1,5 +1,5 @@ """PyTorch Lightning model for base Transformers.""" -from typing import Sequence, Tuple, Type +from typing import Tuple, Type, Set import attr import torch @@ -10,20 +10,20 @@ from text_recognizer.models.metrics import CharacterErrorRate from text_recognizer.models.base import BaseLitModel -@attr.s(auto_attribs=True) +@attr.s(auto_attribs=True, eq=False) class TransformerLitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" - mapping: Type[AbstractMapping] = attr.ib() - start_token: str = attr.ib() - end_token: str = attr.ib() - pad_token: str = attr.ib() + mapping: Type[AbstractMapping] = attr.ib(default=None) + start_token: str = attr.ib(default="<s>") + end_token: str = attr.ib(default="<e>") + pad_token: str = attr.ib(default="<p>") start_index: Tensor = attr.ib(init=False) end_index: Tensor = attr.ib(init=False) pad_index: Tensor = attr.ib(init=False) - ignore_indices: Sequence[str] = attr.ib(init=False) + ignore_indices: Set[Tensor] = attr.ib(init=False) val_cer: CharacterErrorRate = attr.ib(init=False) test_cer: CharacterErrorRate = attr.ib(init=False) @@ -32,7 +32,7 @@ class TransformerLitModel(BaseLitModel): self.start_index = self.mapping.get_index(self.start_token) self.end_index = self.mapping.get_index(self.end_token) self.pad_index = self.mapping.get_index(self.pad_token) - self.ignore_indices = [self.start_index, self.end_index, self.pad_index] + self.ignore_indices = set([self.start_index, self.end_index, self.pad_index]) self.val_cer = CharacterErrorRate(self.ignore_indices) self.test_cer = CharacterErrorRate(self.ignore_indices) diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py index e215e14..22da018 100644 --- a/text_recognizer/models/vqvae.py +++ b/text_recognizer/models/vqvae.py @@ -10,7 +10,7 @@ import wandb from text_recognizer.models.base import BaseLitModel -@attr.s(auto_attribs=True) +@attr.s(auto_attribs=True, eq=False) class VQVAELitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index 7371be4..09cc654 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -13,7 +13,7 @@ from text_recognizer.networks.transformer.positional_encodings import ( ) -@attr.s +@attr.s(eq=False) class ConvTransformer(nn.Module): """Convolutional encoder and transformer decoder network.""" @@ -121,6 +121,7 @@ class ConvTransformer(nn.Module): Returns: Tensor: Sequence of word piece embeddings. """ + context = context.long() context_mask = context != self.pad_index context = self.token_embedding(context) * math.sqrt(self.hidden_dim) context = self.token_pos_encoder(context) diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py index a36150a..b8eb53b 100644 --- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py +++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py @@ -1,4 +1,4 @@ -"""Efficient net.""" +"""Efficientnet backbone.""" from typing import Tuple import attr @@ -12,8 +12,10 @@ from .utils import ( ) -@attr.s +@attr.s(eq=False) class EfficientNet(nn.Module): + """Efficientnet without classification head.""" + def __attrs_pre_init__(self) -> None: super().__init__() @@ -47,11 +49,13 @@ class EfficientNet(nn.Module): @arch.validator def check_arch(self, attribute: attr._make.Attribute, value: str) -> None: + """Validates the efficientnet architecure.""" if value not in self.archs: raise ValueError(f"{value} not a valid architecure.") self.params = self.archs[value] def _build(self) -> None: + """Builds the efficientnet backbone.""" _block_args = block_args() in_channels = 1 # BW out_channels = round_filters(32, self.params) @@ -73,8 +77,9 @@ class EfficientNet(nn.Module): for args in _block_args: args.in_channels = round_filters(args.in_channels, self.params) args.out_channels = round_filters(args.out_channels, self.params) - args.num_repeats = round_repeats(args.num_repeats, self.params) - for _ in range(args.num_repeats): + num_repeats = round_repeats(args.num_repeats, self.params) + del args.num_repeats + for _ in range(num_repeats): self._blocks.append( MBConvBlock( **args, bn_momentum=self.bn_momentum, bn_eps=self.bn_eps, @@ -93,6 +98,7 @@ class EfficientNet(nn.Module): ) def extract_features(self, x: Tensor) -> Tensor: + """Extracts the final feature map layer.""" x = self._conv_stem(x) for i, block in enumerate(self._blocks): stochastic_dropout_rate = self.stochastic_dropout_rate @@ -103,4 +109,5 @@ class EfficientNet(nn.Module): return x def forward(self, x: Tensor) -> Tensor: + """Returns efficientnet image features.""" return self.extract_features(x) diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv.py b/text_recognizer/networks/encoders/efficientnet/mbconv.py index 3aa63d0..e85df87 100644 --- a/text_recognizer/networks/encoders/efficientnet/mbconv.py +++ b/text_recognizer/networks/encoders/efficientnet/mbconv.py @@ -1,76 +1,62 @@ """Mobile inverted residual block.""" -from typing import Any, Optional, Union, Tuple +from typing import Optional, Sequence, Union, Tuple +import attr import torch from torch import nn, Tensor import torch.nn.functional as F -from .utils import stochastic_depth +from text_recognizer.networks.encoders.efficientnet.utils import stochastic_depth +def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]: + """Converts int to tuple.""" + return ( + (stride,) * 2 if isinstance(stride, int) else stride + ) + + +@attr.s(eq=False) class MBConvBlock(nn.Module): """Mobile Inverted Residual Bottleneck block.""" - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - stride: Union[Tuple[int, int], int], - bn_momentum: float, - bn_eps: float, - se_ratio: float, - expand_ratio: int, - *args: Any, - **kwargs: Any, - ) -> None: + def __attrs_pre_init__(self) -> None: super().__init__() - self.kernel_size = kernel_size - self.stride = (stride,) * 2 if isinstance(stride, int) else stride - self.bn_momentum = bn_momentum - self.bn_eps = bn_eps - self.in_channels = in_channels - self.out_channels = out_channels + in_channels: int = attr.ib() + out_channels: int = attr.ib() + kernel_size: Tuple[int, int] = attr.ib() + stride: Tuple[int, int] = attr.ib(converter=_convert_stride) + bn_momentum: float = attr.ib() + bn_eps: float = attr.ib() + se_ratio: float = attr.ib() + expand_ratio: int = attr.ib() + pad: Tuple[int, int, int, int] = attr.ib(init=False) + _inverted_bottleneck: nn.Sequential = attr.ib(init=False) + _depthwise: nn.Sequential = attr.ib(init=False) + _squeeze_excite: nn.Sequential = attr.ib(init=False) + _pointwise: nn.Sequential = attr.ib(init=False) + + @pad.default + def _configure_padding(self) -> Tuple[int, int, int, int]: + """Set padding for convolutional layers.""" if self.stride == (2, 2): - self.pad = [ + return ( (self.kernel_size - 1) // 2 - 1, (self.kernel_size - 1) // 2, - ] * 2 - else: - self.pad = [(self.kernel_size - 1) // 2] * 4 - - # Placeholders for layers. - self._inverted_bottleneck: nn.Sequential = None - self._depthwise: nn.Sequential = None - self._squeeze_excite: nn.Sequential = None - self._pointwise: nn.Sequential = None - - self._build( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - expand_ratio=expand_ratio, - se_ratio=se_ratio, - ) + ) * 2 + return ((self.kernel_size - 1) // 2,) * 4 + + def __attrs_post_init__(self) -> None: + """Post init configuration.""" + self._build() - def _build( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - stride: Union[Tuple[int, int], int], - expand_ratio: int, - se_ratio: float, - ) -> None: - has_se = se_ratio is not None and 0.0 < se_ratio < 1.0 - inner_channels = in_channels * expand_ratio + def _build(self) -> None: + has_se = self.se_ratio is not None and 0.0 < self.se_ratio < 1.0 + inner_channels = self.in_channels * self.expand_ratio self._inverted_bottleneck = ( - self._configure_inverted_bottleneck( - in_channels=in_channels, out_channels=inner_channels, - ) - if expand_ratio != 1 + self._configure_inverted_bottleneck(out_channels=inner_channels) + if self.expand_ratio != 1 else None ) @@ -78,31 +64,23 @@ class MBConvBlock(nn.Module): in_channels=inner_channels, out_channels=inner_channels, groups=inner_channels, - kernel_size=kernel_size, - stride=stride, ) self._squeeze_excite = ( self._configure_squeeze_excite( - in_channels=inner_channels, - out_channels=inner_channels, - se_ratio=se_ratio, + in_channels=inner_channels, out_channels=inner_channels, ) if has_se else None ) - self._pointwise = self._configure_pointwise( - in_channels=inner_channels, out_channels=out_channels - ) + self._pointwise = self._configure_pointwise(in_channels=inner_channels) - def _configure_inverted_bottleneck( - self, in_channels: int, out_channels: int, - ) -> nn.Sequential: + def _configure_inverted_bottleneck(self, out_channels: int) -> nn.Sequential: """Expansion phase.""" return nn.Sequential( nn.Conv2d( - in_channels=in_channels, + in_channels=self.in_channels, out_channels=out_channels, kernel_size=1, bias=False, @@ -114,19 +92,14 @@ class MBConvBlock(nn.Module): ) def _configure_depthwise( - self, - in_channels: int, - out_channels: int, - groups: int, - kernel_size: int, - stride: Union[Tuple[int, int], int], + self, in_channels: int, out_channels: int, groups: int, ) -> nn.Sequential: return nn.Sequential( nn.Conv2d( in_channels=in_channels, out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, + kernel_size=self.kernel_size, + stride=self.stride, groups=groups, bias=False, ), @@ -137,9 +110,9 @@ class MBConvBlock(nn.Module): ) def _configure_squeeze_excite( - self, in_channels: int, out_channels: int, se_ratio: float + self, in_channels: int, out_channels: int ) -> nn.Sequential: - num_squeezed_channels = max(1, int(in_channels * se_ratio)) + num_squeezed_channels = max(1, int(in_channels * self.se_ratio)) return nn.Sequential( nn.Conv2d( in_channels=in_channels, @@ -154,18 +127,18 @@ class MBConvBlock(nn.Module): ), ) - def _configure_pointwise( - self, in_channels: int, out_channels: int - ) -> nn.Sequential: + def _configure_pointwise(self, in_channels: int) -> nn.Sequential: return nn.Sequential( nn.Conv2d( in_channels=in_channels, - out_channels=out_channels, + out_channels=self.out_channels, kernel_size=1, bias=False, ), nn.BatchNorm2d( - num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps + num_features=self.out_channels, + momentum=self.bn_momentum, + eps=self.bn_eps, ), ) @@ -186,8 +159,8 @@ class MBConvBlock(nn.Module): residual = x if self._inverted_bottleneck is not None: x = self._inverted_bottleneck(x) - x = F.pad(x, self.pad) + x = F.pad(x, self.pad) x = self._depthwise(x) if self._squeeze_excite is not None: diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py index 9202cce..37ce29e 100644 --- a/text_recognizer/networks/transformer/attention.py +++ b/text_recognizer/networks/transformer/attention.py @@ -15,7 +15,7 @@ from text_recognizer.networks.transformer.positional_encodings.rotary_embedding ) -@attr.s +@attr.s(eq=False) class Attention(nn.Module): """Standard attention.""" @@ -31,7 +31,6 @@ class Attention(nn.Module): dropout: nn.Dropout = attr.ib(init=False) fc: nn.Linear = attr.ib(init=False) qkv_fn: nn.Sequential = attr.ib(init=False) - attn_fn: F.softmax = attr.ib(init=False, default=F.softmax) def __attrs_post_init__(self) -> None: """Post init configuration.""" @@ -80,7 +79,7 @@ class Attention(nn.Module): else k_mask ) q_mask = rearrange(q_mask, "b i -> b () i ()") - k_mask = rearrange(k_mask, "b i -> b () () j") + k_mask = rearrange(k_mask, "b j -> b () () j") return q_mask * k_mask return @@ -129,7 +128,7 @@ class Attention(nn.Module): if self.causal: energy = self._apply_causal_mask(energy, mask, mask_value, device) - attn = self.attn_fn(energy, dim=-1) + attn = F.softmax(energy, dim=-1) attn = self.dropout(attn) out = einsum("b h i j, b h j d -> b h i d", attn, v) out = rearrange(out, "b h n d -> b n (h d)") diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py index 66c9c50..ce443e5 100644 --- a/text_recognizer/networks/transformer/layers.py +++ b/text_recognizer/networks/transformer/layers.py @@ -12,7 +12,7 @@ from text_recognizer.networks.transformer.positional_encodings.rotary_embedding from text_recognizer.networks.util import load_partial_fn -@attr.s +@attr.s(eq=False) class AttentionLayers(nn.Module): """Standard transfomer layer.""" @@ -101,11 +101,11 @@ class AttentionLayers(nn.Module): return x -@attr.s(auto_attribs=True) +@attr.s(auto_attribs=True, eq=False) class Encoder(AttentionLayers): causal: bool = attr.ib(default=False, init=False) -@attr.s(auto_attribs=True) +@attr.s(auto_attribs=True, eq=False) class Decoder(AttentionLayers): causal: bool = attr.ib(default=True, init=False) |