diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-29 23:59:52 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-29 23:59:52 +0200 |
commit | 34098ccbbbf6379c0bd29a987440b8479c743746 (patch) | |
tree | a8c68e3036503049fc7034c677ec855465f7a8e0 /text_recognizer | |
parent | c032ffb05a7ed86f8fe5d596f94e8997c558cae8 (diff) |
Configs, refactor with attrs, fix attr bug in iam
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/criterions/label_smoothing.py (renamed from text_recognizer/criterions/label_smoothing_loss.py) | 0 | ||||
-rw-r--r-- | text_recognizer/data/base_dataset.py | 1 | ||||
-rw-r--r-- | text_recognizer/data/emnist.py | 2 | ||||
-rw-r--r-- | text_recognizer/data/iam_extended_paragraphs.py | 23 | ||||
-rw-r--r-- | text_recognizer/data/iam_lines.py | 6 | ||||
-rw-r--r-- | text_recognizer/data/iam_paragraphs.py | 7 | ||||
-rw-r--r-- | text_recognizer/data/iam_synthetic_paragraphs.py | 12 | ||||
-rw-r--r-- | text_recognizer/models/base.py | 31 | ||||
-rw-r--r-- | text_recognizer/models/transformer.py | 26 | ||||
-rw-r--r-- | text_recognizer/networks/base.py | 18 | ||||
-rw-r--r-- | text_recognizer/networks/conv_transformer.py (renamed from text_recognizer/networks/cnn_tranformer.py) | 27 |
11 files changed, 72 insertions, 81 deletions
diff --git a/text_recognizer/criterions/label_smoothing_loss.py b/text_recognizer/criterions/label_smoothing.py index 40a7609..40a7609 100644 --- a/text_recognizer/criterions/label_smoothing_loss.py +++ b/text_recognizer/criterions/label_smoothing.py diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py index 4318dfb..c26f1c9 100644 --- a/text_recognizer/data/base_dataset.py +++ b/text_recognizer/data/base_dataset.py @@ -29,6 +29,7 @@ class BaseDataset(Dataset): super().__init__() def __attrs_post_init__(self) -> None: + # TODO: refactor this if len(self.data) != len(self.targets): raise ValueError("Data and targets must be of equal length.") diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py index d51a42a..2d0ac29 100644 --- a/text_recognizer/data/emnist.py +++ b/text_recognizer/data/emnist.py @@ -46,7 +46,7 @@ class EMNIST(BaseDataModule): EMNIST ByClass: 814,255 characters. 62 unbalanced classes. """ - train_fraction: float = attr.ib() + train_fraction: float = attr.ib(default=0.8) transform: Callable = attr.ib(init=False, default=T.Compose([T.ToTensor()])) def __attrs_post_init__(self) -> None: diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py index 886e37e..58c7369 100644 --- a/text_recognizer/data/iam_extended_paragraphs.py +++ b/text_recognizer/data/iam_extended_paragraphs.py @@ -13,23 +13,24 @@ from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs @attr.s(auto_attribs=True) class IAMExtendedParagraphs(BaseDataModule): - train_fraction: float = attr.ib() + augment: bool = attr.ib(default=True) + train_fraction: float = attr.ib(default=0.8) word_pieces: bool = attr.ib(default=False) def __attrs_post_init__(self) -> None: self.iam_paragraphs = IAMParagraphs( - self.batch_size, - self.num_workers, - self.train_fraction, - self.augment, - self.word_pieces, + batch_size=self.batch_size, + num_workers=self.num_workers, + train_fraction=self.train_fraction, + augment=self.augment, + word_pieces=self.word_pieces, ) self.iam_synthetic_paragraphs = IAMSyntheticParagraphs( - self.batch_size, - self.num_workers, - self.train_fraction, - self.augment, - self.word_pieces, + batch_size=self.batch_size, + num_workers=self.num_workers, + train_fraction=self.train_fraction, + augment=self.augment, + word_pieces=self.word_pieces, ) self.dims = self.iam_paragraphs.dims diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py index e45e5c8..705cfa3 100644 --- a/text_recognizer/data/iam_lines.py +++ b/text_recognizer/data/iam_lines.py @@ -34,6 +34,7 @@ SEED = 4711 PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_lines" IMAGE_HEIGHT = 56 IMAGE_WIDTH = 1024 +MAX_LABEL_LENGTH = 89 @attr.s(auto_attribs=True) @@ -42,11 +43,12 @@ class IAMLines(BaseDataModule): augment: bool = attr.ib(default=True) fraction: float = attr.ib(default=0.8) + dims: Tuple[int, int, int] = attr.ib(init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH)) + output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1)) def __attrs_post_init__(self) -> None: + # TODO: refactor this self.mapping, self.inverse_mapping, _ = emnist_mapping() - self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH) - self.output_dims = (89, 1) def prepare_data(self) -> None: """Creates the IAM lines dataset if not existing.""" diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index bdfb490..9977978 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -41,6 +41,8 @@ class IAMParagraphs(BaseDataModule): augment: bool = attr.ib(default=True) train_fraction: float = attr.ib(default=0.8) word_pieces: bool = attr.ib(default=False) + dims: Tuple[int, int, int] = attr.ib(init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH)) + output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1)) def __attrs_post_init__(self) -> None: self.mapping, self.inverse_mapping, _ = emnist_mapping( @@ -49,11 +51,6 @@ class IAMParagraphs(BaseDataModule): if self.word_pieces: self.mapping = WordPieceMapping() - self.train_fraction = train_fraction - - self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH) - self.output_dims = (MAX_LABEL_LENGTH, 1) - def prepare_data(self) -> None: """Create data for training/testing.""" if PROCESSED_DATA_DIRNAME.exists(): diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py index 00fa2b6..a3697e7 100644 --- a/text_recognizer/data/iam_synthetic_paragraphs.py +++ b/text_recognizer/data/iam_synthetic_paragraphs.py @@ -2,6 +2,7 @@ import random from typing import Any, List, Sequence, Tuple +import attr from loguru import logger import numpy as np from PIL import Image @@ -33,19 +34,10 @@ PROCESSED_DATA_DIRNAME = ( ) +@attr.s(auto_attribs=True) class IAMSyntheticParagraphs(IAMParagraphs): """IAM Handwriting database of synthetic paragraphs.""" - def __init__( - self, - batch_size: int = 16, - num_workers: int = 0, - train_fraction: float = 0.8, - augment: bool = True, - word_pieces: bool = False, - ) -> None: - super().__init__(batch_size, num_workers, train_fraction, augment, word_pieces) - def prepare_data(self) -> None: """Prepare IAM lines to be used to generate paragraphs.""" if PROCESSED_DATA_DIRNAME.exists(): diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index f95df0f..3b83056 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -3,20 +3,25 @@ from typing import Any, Dict, List, Tuple, Type import attr import hydra -import loguru.logger as log +from loguru import logger as log from omegaconf import DictConfig -import pytorch_lightning as LightningModule +from pytorch_lightning import LightningModule import torch from torch import nn from torch import Tensor import torchmetrics +from text_recognizer.networks.base import BaseNetwork + @attr.s class BaseLitModel(LightningModule): """Abstract PyTorch Lightning class.""" - network: Type[nn.Module] = attr.ib() + def __attrs_pre_init__(self) -> None: + super().__init__() + + network: Type[BaseNetwork] = attr.ib() criterion_config: DictConfig = attr.ib(converter=DictConfig) optimizer_config: DictConfig = attr.ib(converter=DictConfig) lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig) @@ -24,23 +29,13 @@ class BaseLitModel(LightningModule): interval: str = attr.ib() monitor: str = attr.ib(default="val/loss") - loss_fn = attr.ib(init=False) - - train_acc = attr.ib(init=False) - val_acc = attr.ib(init=False) - test_acc = attr.ib(init=False) - - def __attrs_pre_init__(self) -> None: - super().__init__() - - def __attrs_post_init__(self) -> None: - self.loss_fn = self._configure_criterion() + loss_fn: Type[nn.Module] = attr.ib(init=False) - # Accuracy metric - self.train_acc = torchmetrics.Accuracy() - self.val_acc = torchmetrics.Accuracy() - self.test_acc = torchmetrics.Accuracy() + train_acc: torchmetrics.Accuracy = attr.ib(init=False, default=torchmetrics.Accuracy()) + val_acc: torchmetrics.Accuracy = attr.ib(init=False, default=torchmetrics.Accuracy()) + test_acc: torchmetrics.Accuracy = attr.ib(init=False, default=torchmetrics.Accuracy()) + @loss_fn.default def configure_criterion(self) -> Type[nn.Module]: """Returns a loss functions.""" log.info(f"Instantiating criterion <{self.criterion_config._target_}>") diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 8c9fe8a..f5cb491 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -1,13 +1,11 @@ """PyTorch Lightning model for base Transformers.""" -from typing import Dict, List, Optional, Union, Tuple, Type +from typing import Dict, List, Optional, Sequence, Union, Tuple, Type import attr import hydra from omegaconf import DictConfig from torch import nn, Tensor -from text_recognizer.data.emnist import emnist_mapping -from text_recognizer.data.mappings import AbstractMapping from text_recognizer.models.metrics import CharacterErrorRate from text_recognizer.models.base import BaseLitModel @@ -16,30 +14,18 @@ from text_recognizer.models.base import BaseLitModel class TransformerLitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" - mapping_config: DictConfig = attr.ib(converter=DictConfig) + ignore_tokens: Sequence[str] = attr.ib(default=("<s>", "<e>", "<p>",)) + val_cer: CharacterErrorRate = attr.ib(init=False) + test_cer: CharacterErrorRate = attr.ib(init=False) def __attrs_post_init__(self) -> None: - self.mapping, ignore_tokens = self._configure_mapping() - self.val_cer = CharacterErrorRate(ignore_tokens) - self.test_cer = CharacterErrorRate(ignore_tokens) + self.val_cer = CharacterErrorRate(self.ignore_tokens) + self.test_cer = CharacterErrorRate(self.ignore_tokens) def forward(self, data: Tensor) -> Tensor: """Forward pass with the transformer network.""" return self.network.predict(data) - @staticmethod - def _configure_mapping() -> Tuple[Type[AbstractMapping], List[int]]: - """Configure mapping.""" - # TODO: Fix me!!! - # Load config with hydra - mapping, inverse_mapping, _ = emnist_mapping(["\n"]) - start_index = inverse_mapping["<s>"] - end_index = inverse_mapping["<e>"] - pad_index = inverse_mapping["<p>"] - ignore_tokens = [start_index, end_index, pad_index] - # TODO: add case for sentence pieces - return mapping, ignore_tokens - def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: """Training step.""" data, targets = batch diff --git a/text_recognizer/networks/base.py b/text_recognizer/networks/base.py new file mode 100644 index 0000000..07b6a32 --- /dev/null +++ b/text_recognizer/networks/base.py @@ -0,0 +1,18 @@ +"""Base network with required methods.""" +from abc import abstractmethod + +import attr +from torch import nn, Tensor + + +@attr.s +class BaseNetwork(nn.Module): + """Base network.""" + + def __attrs_pre_init__(self) -> None: + super().__init__() + + @abstractmethod + def predict(self, x: Tensor) -> Tensor: + """Return token indices for predictions.""" + ... diff --git a/text_recognizer/networks/cnn_tranformer.py b/text_recognizer/networks/conv_transformer.py index ce7ec43..4acdc36 100644 --- a/text_recognizer/networks/cnn_tranformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -7,6 +7,7 @@ import torch from torch import nn, Tensor from text_recognizer.data.mappings import AbstractMapping +from text_recognizer.networks.base import BaseNetwork from text_recognizer.networks.encoders.efficientnet import EfficientNet from text_recognizer.networks.transformer.layers import Decoder from text_recognizer.networks.transformer.positional_encodings import ( @@ -15,39 +16,37 @@ from text_recognizer.networks.transformer.positional_encodings import ( ) -@attr.s -class Reader(nn.Module): - def __attrs_pre_init__(self) -> None: - super().__init__() - +@attr.s(auto_attribs=True) +class ConvTransformer(BaseNetwork): # Parameters and placeholders, input_dims: Tuple[int, int, int] = attr.ib() hidden_dim: int = attr.ib() dropout_rate: float = attr.ib() max_output_len: int = attr.ib() num_classes: int = attr.ib() - padding_idx: int = attr.ib() start_token: str = attr.ib() - start_index: int = attr.ib(init=False) + start_index: Tensor = attr.ib(init=False) end_token: str = attr.ib() - end_index: int = attr.ib(init=False) + end_index: Tensor = attr.ib(init=False) pad_token: str = attr.ib() - pad_index: int = attr.ib(init=False) + pad_index: Tensor = attr.ib(init=False) # Modules. encoder: EfficientNet = attr.ib() decoder: Decoder = attr.ib() + mapping: Type[AbstractMapping] = attr.ib() + latent_encoder: nn.Sequential = attr.ib(init=False) token_embedding: nn.Embedding = attr.ib(init=False) token_pos_encoder: PositionalEncoding = attr.ib(init=False) head: nn.Linear = attr.ib(init=False) - mapping: Type[AbstractMapping] = attr.ib(init=False) def __attrs_post_init__(self) -> None: """Post init configuration.""" - self.start_index = int(self.mapping.get_index(self.start_token)) - self.end_index = int(self.mapping.get_index(self.end_token)) - self.pad_index = int(self.mapping.get_index(self.pad_token)) + self.start_index = self.mapping.get_index(self.start_token) + self.end_index = self.mapping.get_index(self.end_token) + self.pad_index = self.mapping.get_index(self.pad_token) + # Latent projector for down sampling number of filters and 2d # positional encoding. self.latent_encoder = nn.Sequential( @@ -130,7 +129,7 @@ class Reader(nn.Module): Returns: Tensor: Sequence of word piece embeddings. """ - context_mask = context != self.padding_idx + context_mask = context != self.pad_index context = self.token_embedding(context) * math.sqrt(self.hidden_dim) context = self.token_pos_encoder(context) out = self.decoder(x=context, context=z, mask=context_mask) |