diff options
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/callbacks/wandb_callbacks.py | 95 | ||||
-rw-r--r-- | text_recognizer/data/base_data_module.py | 29 | ||||
-rw-r--r-- | text_recognizer/data/emnist.py | 22 | ||||
-rw-r--r-- | text_recognizer/data/emnist_lines.py | 35 | ||||
-rw-r--r-- | text_recognizer/data/iam.py | 6 | ||||
-rw-r--r-- | text_recognizer/data/iam_extended_paragraphs.py | 33 | ||||
-rw-r--r-- | text_recognizer/data/iam_lines.py | 22 | ||||
-rw-r--r-- | text_recognizer/data/iam_paragraphs.py | 32 | ||||
-rw-r--r-- | text_recognizer/models/base.py | 5 | ||||
-rw-r--r-- | text_recognizer/models/metrics.py | 15 | ||||
-rw-r--r-- | text_recognizer/models/transformer.py | 26 | ||||
-rw-r--r-- | text_recognizer/models/vqvae.py | 34 | ||||
-rw-r--r-- | text_recognizer/networks/util.py | 2 |
13 files changed, 188 insertions, 168 deletions
diff --git a/text_recognizer/callbacks/wandb_callbacks.py b/text_recognizer/callbacks/wandb_callbacks.py index 4186b4a..d9d81f6 100644 --- a/text_recognizer/callbacks/wandb_callbacks.py +++ b/text_recognizer/callbacks/wandb_callbacks.py @@ -93,6 +93,40 @@ class LogTextPredictions(Callback): def __attrs_pre_init__(self) -> None: super().__init__() + def _log_predictions( + stage: str, trainer: Trainer, pl_module: LightningModule + ) -> None: + """Logs the predicted text contained in the images.""" + if not self.ready: + return None + + logger = get_wandb_logger(trainer) + experiment = logger.experiment + + # Get a validation batch from the validation dataloader. + samples = next(iter(trainer.datamodule.val_dataloader())) + imgs, labels = samples + + imgs = imgs.to(device=pl_module.device) + logits = pl_module(imgs) + + mapping = pl_module.mapping + experiment.log( + { + f"OCR/{experiment.name}/{stage}": [ + wandb.Image( + img, + caption=f"Pred: {mapping.get_text(pred)}, Label: {mapping.get_text(label)}", + ) + for img, pred, label in zip( + imgs[: self.num_samples], + logits[: self.num_samples], + labels[: self.num_samples], + ) + ] + } + ) + def on_sanity_check_start( self, trainer: Trainer, pl_module: LightningModule ) -> None: @@ -107,6 +141,27 @@ class LogTextPredictions(Callback): self, trainer: Trainer, pl_module: LightningModule ) -> None: """Logs predictions on validation epoch end.""" + self._log_predictions(stage="val", trainer=trainer, pl_module=pl_module) + + def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Logs predictions on train epoch end.""" + self._log_predictions(stage="test", trainer=trainer, pl_module=pl_module) + + +@attr.s +class LogReconstuctedImages(Callback): + """Log reconstructions of images.""" + + num_samples: int = attr.ib(default=8) + ready: bool = attr.ib(default=True) + + def __attrs_pre_init__(self) -> None: + super().__init__() + + def _log_reconstruction( + self, stage: str, trainer: Trainer, pl_module: LightningModule + ) -> None: + """Logs the reconstructions.""" if not self.ready: return None @@ -115,24 +170,42 @@ class LogTextPredictions(Callback): # Get a validation batch from the validation dataloader. samples = next(iter(trainer.datamodule.val_dataloader())) - imgs, labels = samples + imgs, _ = samples imgs = imgs.to(device=pl_module.device) - logits = pl_module(imgs) + reconstructions = pl_module(imgs) - mapping = pl_module.mapping experiment.log( { - f"Images/{experiment.name}": [ - wandb.Image( - img, - caption=f"Pred: {mapping.get_text(pred)}, Label: {mapping.get_text(label)}", - ) - for img, pred, label in zip( + f"Reconstructions/{experiment.name}/{stage}": [ + [ + wandb.Image(img), + wandb.Image(rec), + ] + for img, rec in zip( imgs[: self.num_samples], - logits[: self.num_samples], - labels[: self.num_samples], + reconstructions[: self.num_samples], ) ] } ) + + def on_sanity_check_start( + self, trainer: Trainer, pl_module: LightningModule + ) -> None: + """Sets ready attribute.""" + self.ready = False + + def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Start executing this callback only after all validation sanity checks end.""" + self.ready = True + + def on_validation_epoch_end( + self, trainer: Trainer, pl_module: LightningModule + ) -> None: + """Logs predictions on validation epoch end.""" + self._log_reconstruction(stage="val", trainer=trainer, pl_module=pl_module) + + def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Logs predictions on train epoch end.""" + self._log_reconstruction(stage="test", trainer=trainer, pl_module=pl_module) diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py index de5628f..18b1996 100644 --- a/text_recognizer/data/base_data_module.py +++ b/text_recognizer/data/base_data_module.py @@ -1,11 +1,13 @@ """Base lightning DataModule class.""" from pathlib import Path -from typing import Dict +from typing import Any, Dict, Tuple import attr -import pytorch_lightning as LightningDataModule +from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader +from text_recognizer.data.base_dataset import BaseDataset + def load_and_print_info(data_module_class: type) -> None: """Load dataset and print dataset information.""" @@ -19,17 +21,20 @@ def load_and_print_info(data_module_class: type) -> None: class BaseDataModule(LightningDataModule): """Base PyTorch Lightning DataModule.""" - batch_size: int = attr.ib(default=16) - num_workers: int = attr.ib(default=0) - def __attrs_pre_init__(self) -> None: super().__init__() - def __attrs_post_init__(self) -> None: - # Placeholders for subclasses. - self.dims = None - self.output_dims = None - self.mapping = None + batch_size: int = attr.ib(default=16) + num_workers: int = attr.ib(default=0) + + # Placeholders + data_train: BaseDataset = attr.ib(init=False, default=None) + data_val: BaseDataset = attr.ib(init=False, default=None) + data_test: BaseDataset = attr.ib(init=False, default=None) + dims: Tuple[int, ...] = attr.ib(init=False, default=None) + output_dims: Tuple[int, ...] = attr.ib(init=False, default=None) + mapping: Any = attr.ib(init=False, default=None) + inverse_mapping: Dict[str, int] = attr.ib(init=False) @classmethod def data_dirname(cls) -> Path: @@ -58,9 +63,7 @@ class BaseDataModule(LightningDataModule): stage (Any): Variable to set splits. """ - self.data_train = None - self.data_val = None - self.data_test = None + pass def train_dataloader(self) -> DataLoader: """Retun DataLoader for train data.""" diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py index 824b947..d51a42a 100644 --- a/text_recognizer/data/emnist.py +++ b/text_recognizer/data/emnist.py @@ -3,9 +3,10 @@ import json import os from pathlib import Path import shutil -from typing import Dict, List, Optional, Sequence, Tuple +from typing import Callable, Dict, List, Optional, Sequence, Tuple import zipfile +import attr import h5py from loguru import logger import numpy as np @@ -32,6 +33,7 @@ PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5" ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json" +@attr.s(auto_attribs=True) class EMNIST(BaseDataModule): """Lightning DataModule class for loading EMNIST dataset. @@ -44,18 +46,12 @@ class EMNIST(BaseDataModule): EMNIST ByClass: 814,255 characters. 62 unbalanced classes. """ - def __init__( - self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8 - ) -> None: - super().__init__(batch_size, num_workers) - self.train_fraction = train_fraction - self.mapping, self.inverse_mapping, self.input_shape = emnist_mapping() - self.data_train = None - self.data_val = None - self.data_test = None - self.transform = T.Compose([T.ToTensor()]) - self.dims = (1, *self.input_shape) - self.output_dims = (1,) + train_fraction: float = attr.ib() + transform: Callable = attr.ib(init=False, default=T.Compose([T.ToTensor()])) + + def __attrs_post_init__(self) -> None: + self.mapping, self.inverse_mapping, input_shape = emnist_mapping() + self.dims = (1, *input_shape) def prepare_data(self) -> None: """Downloads dataset if not present.""" diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py index 9650198..4747508 100644 --- a/text_recognizer/data/emnist_lines.py +++ b/text_recognizer/data/emnist_lines.py @@ -3,6 +3,7 @@ from collections import defaultdict from pathlib import Path from typing import Callable, Dict, Tuple +import attr import h5py from loguru import logger import numpy as np @@ -31,31 +32,20 @@ IMAGE_X_PADDING = 28 MAX_OUTPUT_LENGTH = 89 # Same as IAMLines +@attr.s(auto_attribs=True) class EMNISTLines(BaseDataModule): """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST,""" - def __init__( - self, - augment: bool = True, - batch_size: int = 128, - num_workers: int = 0, - max_length: int = 32, - min_overlap: float = 0.0, - max_overlap: float = 0.33, - num_train: int = 10_000, - num_val: int = 2_000, - num_test: int = 2_000, - ) -> None: - super().__init__(batch_size, num_workers) - - self.augment = augment - self.max_length = max_length - self.min_overlap = min_overlap - self.max_overlap = max_overlap - self.num_train = num_train - self.num_val = num_val - self.num_test = num_test + augment: bool = attr.ib(default=True) + max_length: int = attr.ib(default=128) + min_overlap: float = attr.ib(default=0.0) + max_overlap: float = attr.ib(default=0.33) + num_train: int = attr.ib(default=10_000) + num_val: int = attr.ib(default=2_000) + num_test: int = attr.ib(default=2_000) + emnist: EMNIST = attr.ib(init=False, default=None) + def __attrs_post_init__(self) -> None: self.emnist = EMNIST() self.mapping = self.emnist.mapping @@ -75,9 +65,6 @@ class EMNISTLines(BaseDataModule): raise ValueError("max_length greater than MAX_OUTPUT_LENGTH") self.output_dims = (MAX_OUTPUT_LENGTH, 1) - self.data_train: BaseDataset = None - self.data_val: BaseDataset = None - self.data_test: BaseDataset = None @property def data_filename(self) -> Path: diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py index 261c8d3..3982c4f 100644 --- a/text_recognizer/data/iam.py +++ b/text_recognizer/data/iam.py @@ -5,6 +5,7 @@ from typing import Any, Dict, List import xml.etree.ElementTree as ElementTree import zipfile +import attr from boltons.cacheutils import cachedproperty from loguru import logger import toml @@ -22,6 +23,7 @@ DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be. LINE_REGION_PADDING = 16 # Add this many pixels around the exact coordinates. +@attr.s(auto_attribs=True) class IAM(BaseDataModule): """ "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text, @@ -35,9 +37,7 @@ class IAM(BaseDataModule): The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only. """ - def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None: - super().__init__(batch_size, num_workers) - self.metadata = toml.load(METADATA_FILENAME) + metadata: Dict = attr.ib(init=False, default=toml.load(METADATA_FILENAME)) def prepare_data(self) -> None: if self.xml_filenames: diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py index 0a30a42..886e37e 100644 --- a/text_recognizer/data/iam_extended_paragraphs.py +++ b/text_recognizer/data/iam_extended_paragraphs.py @@ -1,4 +1,7 @@ """IAM original and sythetic dataset class.""" +from typing import Dict, List + +import attr from torch.utils.data import ConcatDataset from text_recognizer.data.base_dataset import BaseDataset @@ -7,22 +10,26 @@ from text_recognizer.data.iam_paragraphs import IAMParagraphs from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs +@attr.s(auto_attribs=True) class IAMExtendedParagraphs(BaseDataModule): - def __init__( - self, - batch_size: int = 16, - num_workers: int = 0, - train_fraction: float = 0.8, - augment: bool = True, - word_pieces: bool = False, - ) -> None: - super().__init__(batch_size, num_workers) + train_fraction: float = attr.ib() + word_pieces: bool = attr.ib(default=False) + + def __attrs_post_init__(self) -> None: self.iam_paragraphs = IAMParagraphs( - batch_size, num_workers, train_fraction, augment, word_pieces, + self.batch_size, + self.num_workers, + self.train_fraction, + self.augment, + self.word_pieces, ) self.iam_synthetic_paragraphs = IAMSyntheticParagraphs( - batch_size, num_workers, train_fraction, augment, word_pieces, + self.batch_size, + self.num_workers, + self.train_fraction, + self.augment, + self.word_pieces, ) self.dims = self.iam_paragraphs.dims @@ -30,10 +37,6 @@ class IAMExtendedParagraphs(BaseDataModule): self.mapping = self.iam_paragraphs.mapping self.inverse_mapping = self.iam_paragraphs.inverse_mapping - self.data_train: BaseDataset = None - self.data_val: BaseDataset = None - self.data_test: BaseDataset = None - def prepare_data(self) -> None: """Prepares the paragraphs data.""" self.iam_paragraphs.prepare_data() diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py index 9c78a22..e45e5c8 100644 --- a/text_recognizer/data/iam_lines.py +++ b/text_recognizer/data/iam_lines.py @@ -7,8 +7,9 @@ dataset. import json from pathlib import Path import random -from typing import List, Sequence, Tuple +from typing import Dict, List, Sequence, Tuple +import attr from loguru import logger from PIL import Image, ImageFile, ImageOps import numpy as np @@ -35,26 +36,17 @@ IMAGE_HEIGHT = 56 IMAGE_WIDTH = 1024 +@attr.s(auto_attribs=True) class IAMLines(BaseDataModule): """IAM handwritten lines dataset.""" - def __init__( - self, - augment: bool = True, - fraction: float = 0.8, - batch_size: int = 128, - num_workers: int = 0, - ) -> None: - # TODO: add transforms - super().__init__(batch_size, num_workers) - self.augment = augment - self.fraction = fraction + augment: bool = attr.ib(default=True) + fraction: float = attr.ib(default=0.8) + + def __attrs_post_init__(self) -> None: self.mapping, self.inverse_mapping, _ = emnist_mapping() self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH) self.output_dims = (89, 1) - self.data_train: BaseDataset = None - self.data_val: BaseDataset = None - self.data_test: BaseDataset = None def prepare_data(self) -> None: """Creates the IAM lines dataset if not existing.""" diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index fe60e99..445b788 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -3,6 +3,7 @@ import json from pathlib import Path from typing import Dict, List, Optional, Sequence, Tuple +import attr from loguru import logger import numpy as np from PIL import Image, ImageOps @@ -33,33 +34,25 @@ IMAGE_WIDTH = 1280 // IMAGE_SCALE_FACTOR MAX_LABEL_LENGTH = 682 +@attr.s(auto_attribs=True) class IAMParagraphs(BaseDataModule): """IAM handwriting database paragraphs.""" - def __init__( - self, - batch_size: int = 16, - num_workers: int = 0, - train_fraction: float = 0.8, - augment: bool = True, - word_pieces: bool = False, - ) -> None: - super().__init__(batch_size, num_workers) - self.augment = augment - self.word_pieces = word_pieces + augment: bool = attr.ib(default=True) + train_fraction: float = attr.ib(default=0.8) + word_pieces: bool = attr.ib(default=False) + + def __attrs_post_init__(self) -> None: self.mapping, self.inverse_mapping, _ = emnist_mapping( extra_symbols=[NEW_LINE_TOKEN] ) - if word_pieces: + if self.word_pieces: self.mapping = WordPieceMapping() self.train_fraction = train_fraction self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH) self.output_dims = (MAX_LABEL_LENGTH, 1) - self.data_train: BaseDataset = None - self.data_val: BaseDataset = None - self.data_test: BaseDataset = None def prepare_data(self) -> None: """Create data for training/testing.""" @@ -166,7 +159,10 @@ def get_dataset_properties() -> Dict: "min": min(_get_property_values("num_lines")), "max": max(_get_property_values("num_lines")), }, - "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0),}, + "crop_shape": { + "min": crop_shapes.min(axis=0), + "max": crop_shapes.max(axis=0), + }, "aspect_ratio": { "min": aspect_ratio.min(axis=0), "max": aspect_ratio.max(axis=0), @@ -287,7 +283,9 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> T.Compose: ), T.ColorJitter(brightness=(0.8, 1.6)), T.RandomAffine( - degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR, + degrees=1, + shear=(-10, 10), + interpolation=InterpolationMode.BILINEAR, ), ] else: diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 8dc7a36..f95df0f 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -5,7 +5,7 @@ import attr import hydra import loguru.logger as log from omegaconf import DictConfig -import pytorch_lightning as pl +import pytorch_lightning as LightningModule import torch from torch import nn from torch import Tensor @@ -13,7 +13,7 @@ import torchmetrics @attr.s -class BaseLitModel(pl.LightningModule): +class BaseLitModel(LightningModule): """Abstract PyTorch Lightning class.""" network: Type[nn.Module] = attr.ib() @@ -80,7 +80,6 @@ class BaseLitModel(pl.LightningModule): """Configures optimizer and lr scheduler.""" optimizer = self._configure_optimizer() scheduler = self._configure_lr_scheduler(optimizer) - return [optimizer], [scheduler] def forward(self, data: Tensor) -> Tensor: diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py index 58d0537..4117ae2 100644 --- a/text_recognizer/models/metrics.py +++ b/text_recognizer/models/metrics.py @@ -1,18 +1,23 @@ """Character Error Rate (CER).""" -from typing import Sequence +from typing import Set, Sequence +import attr import editdistance import torch from torch import Tensor -import torchmetrics +from torchmetrics import Metric -class CharacterErrorRate(torchmetrics.Metric): +@attr.s +class CharacterErrorRate(Metric): """Character error rate metric, computed using Levenshtein distance.""" - def __init__(self, ignore_tokens: Sequence[int], *args) -> None: + ignore_tokens: Set = attr.ib(converter=set) + error: Tensor = attr.ib(init=False) + total: Tensor = attr.ib(init=False) + + def __attrs_post_init__(self) -> None: super().__init__() - self.ignore_tokens = set(ignore_tokens) self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index ea54d83..8c9fe8a 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -2,35 +2,24 @@ from typing import Dict, List, Optional, Union, Tuple, Type import attr +import hydra from omegaconf import DictConfig from torch import nn, Tensor from text_recognizer.data.emnist import emnist_mapping from text_recognizer.data.mappings import AbstractMapping from text_recognizer.models.metrics import CharacterErrorRate -from text_recognizer.models.base import LitBaseModel +from text_recognizer.models.base import BaseLitModel -@attr.s -class TransformerLitModel(LitBaseModel): +@attr.s(auto_attribs=True) +class TransformerLitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" - network: Type[nn.Module] = attr.ib() - criterion_config: DictConfig = attr.ib(converter=DictConfig) - optimizer_config: DictConfig = attr.ib(converter=DictConfig) - lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig) - monitor: str = attr.ib() - mapping: Type[AbstractMapping] = attr.ib() + mapping_config: DictConfig = attr.ib(converter=DictConfig) def __attrs_post_init__(self) -> None: - super().__init__( - network=self.network, - optimizer_config=self.optimizer_config, - lr_scheduler_config=self.lr_scheduler_config, - criterion_config=self.criterion_config, - monitor=self.monitor, - ) - self.mapping, ignore_tokens = self.configure_mapping(mapping) + self.mapping, ignore_tokens = self._configure_mapping() self.val_cer = CharacterErrorRate(ignore_tokens) self.test_cer = CharacterErrorRate(ignore_tokens) @@ -39,9 +28,10 @@ class TransformerLitModel(LitBaseModel): return self.network.predict(data) @staticmethod - def configure_mapping(mapping: Optional[List[str]]) -> Tuple[List[str], List[int]]: + def _configure_mapping() -> Tuple[Type[AbstractMapping], List[int]]: """Configure mapping.""" # TODO: Fix me!!! + # Load config with hydra mapping, inverse_mapping, _ = emnist_mapping(["\n"]) start_index = inverse_mapping["<s>"] end_index = inverse_mapping["<e>"] diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py index 7dc950f..0172163 100644 --- a/text_recognizer/models/vqvae.py +++ b/text_recognizer/models/vqvae.py @@ -1,49 +1,23 @@ """PyTorch Lightning model for base Transformers.""" from typing import Any, Dict, Union, Tuple, Type +import attr from omegaconf import DictConfig from torch import nn from torch import Tensor import wandb -from text_recognizer.models.base import LitBaseModel +from text_recognizer.models.base import BaseLitModel -class LitVQVAEModel(LitBaseModel): +@attr.s(auto_attribs=True) +class VQVAELitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" - def __init__( - self, - network: Type[nn.Module], - optimizer: Union[DictConfig, Dict], - lr_scheduler: Union[DictConfig, Dict], - criterion: Union[DictConfig, Dict], - monitor: str = "val/loss", - *args: Any, - **kwargs: Dict, - ) -> None: - super().__init__(network, optimizer, lr_scheduler, criterion, monitor) - def forward(self, data: Tensor) -> Tensor: """Forward pass with the transformer network.""" return self.network.predict(data) - def _log_prediction( - self, data: Tensor, reconstructions: Tensor, title: str - ) -> None: - """Logs prediction on image with wandb.""" - try: - self.logger.experiment.log( - { - title: [ - wandb.Image(data[0]), - wandb.Image(reconstructions[0]), - ] - } - ) - except AttributeError: - pass - def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: """Training step.""" data, _ = batch diff --git a/text_recognizer/networks/util.py b/text_recognizer/networks/util.py index 109bf4d..85094f1 100644 --- a/text_recognizer/networks/util.py +++ b/text_recognizer/networks/util.py @@ -1,4 +1,4 @@ -"""Miscellaneous neural network functionality.""" +"""Miscellaneous neural network utility functionality.""" from typing import Type from torch import nn |