author     Gustaf Rydholm <gustaf.rydholm@gmail.com>    2021-07-29 23:59:52 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>    2021-07-29 23:59:52 +0200
commit     34098ccbbbf6379c0bd29a987440b8479c743746 (patch)
tree       a8c68e3036503049fc7034c677ec855465f7a8e0 /text_recognizer
parent     c032ffb05a7ed86f8fe5d596f94e8997c558cae8 (diff)
Configs, refactor with attrs, fix attr bug in iam
Diffstat (limited to 'text_recognizer')
-rw-r--r--  text_recognizer/criterions/label_smoothing.py (renamed from text_recognizer/criterions/label_smoothing_loss.py)   0
-rw-r--r--  text_recognizer/data/base_dataset.py                1
-rw-r--r--  text_recognizer/data/emnist.py                      2
-rw-r--r--  text_recognizer/data/iam_extended_paragraphs.py    23
-rw-r--r--  text_recognizer/data/iam_lines.py                   6
-rw-r--r--  text_recognizer/data/iam_paragraphs.py              7
-rw-r--r--  text_recognizer/data/iam_synthetic_paragraphs.py   12
-rw-r--r--  text_recognizer/models/base.py                     31
-rw-r--r--  text_recognizer/models/transformer.py              26
-rw-r--r--  text_recognizer/networks/base.py                   18
-rw-r--r--  text_recognizer/networks/conv_transformer.py (renamed from text_recognizer/networks/cnn_tranformer.py)  27
11 files changed, 72 insertions, 81 deletions
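
The commit moves the data modules and models from hand-written __init__ methods
to attrs-generated ones. A minimal runnable sketch of the pattern used
throughout (class and attribute names here are illustrative, not from the repo):

    import attr


    @attr.s(auto_attribs=True)
    class ExampleDataModule:
        # Constructor arguments become class-level attr.ib declarations;
        # attr.ib(default=...) replaces hand-written keyword defaults.
        batch_size: int = attr.ib(default=16)
        train_fraction: float = attr.ib(default=0.8)
        # init=False marks derived state that is not a constructor argument.
        dims: tuple = attr.ib(init=False, default=(1, 56, 1024))

        def __attrs_post_init__(self) -> None:
            # Validation that used to live in __init__ moves here.
            if not 0.0 < self.train_fraction <= 1.0:
                raise ValueError("train_fraction must be in (0, 1].")

attrs then generates __init__, __repr__, and comparison methods from the
declarations, which is what allows the deletions in the hunks below.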
diff --git a/text_recognizer/criterions/label_smoothing_loss.py b/text_recognizer/criterions/label_smoothing.py
index 40a7609..40a7609 100644
--- a/text_recognizer/criterions/label_smoothing_loss.py
+++ b/text_recognizer/criterions/label_smoothing.py
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py
index 4318dfb..c26f1c9 100644
--- a/text_recognizer/data/base_dataset.py
+++ b/text_recognizer/data/base_dataset.py
@@ -29,6 +29,7 @@ class BaseDataset(Dataset):
         super().__init__()

     def __attrs_post_init__(self) -> None:
+        # TODO: refactor this
         if len(self.data) != len(self.targets):
             raise ValueError("Data and targets must be of equal length.")
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index d51a42a..2d0ac29 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -46,7 +46,7 @@ class EMNIST(BaseDataModule):
     EMNIST ByClass: 814,255 characters. 62 unbalanced classes.
     """

-    train_fraction: float = attr.ib()
+    train_fraction: float = attr.ib(default=0.8)
     transform: Callable = attr.ib(init=False, default=T.Compose([T.ToTensor()]))

     def __attrs_post_init__(self) -> None:
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 886e37e..58c7369 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -13,23 +13,24 @@ from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs
 @attr.s(auto_attribs=True)
 class IAMExtendedParagraphs(BaseDataModule):
-    train_fraction: float = attr.ib()
+    augment: bool = attr.ib(default=True)
+    train_fraction: float = attr.ib(default=0.8)
     word_pieces: bool = attr.ib(default=False)

     def __attrs_post_init__(self) -> None:
         self.iam_paragraphs = IAMParagraphs(
-            self.batch_size,
-            self.num_workers,
-            self.train_fraction,
-            self.augment,
-            self.word_pieces,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            train_fraction=self.train_fraction,
+            augment=self.augment,
+            word_pieces=self.word_pieces,
         )
         self.iam_synthetic_paragraphs = IAMSyntheticParagraphs(
-            self.batch_size,
-            self.num_workers,
-            self.train_fraction,
-            self.augment,
-            self.word_pieces,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            train_fraction=self.train_fraction,
+            augment=self.augment,
+            word_pieces=self.word_pieces,
         )
         self.dims = self.iam_paragraphs.dims
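
This hunk is the "attr bug in iam" fix from the commit message: an
attrs-generated __init__ takes its arguments in attribute-definition order,
base-class attributes first, so inserting the augment attribute ahead of
train_fraction silently re-binds positional call sites. A small sketch of the
hazard (names hypothetical, not from the repo):

    import attr


    @attr.s(auto_attribs=True)
    class Base:
        batch_size: int = 16


    @attr.s(auto_attribs=True)
    class Child(Base):
        augment: bool = True         # newly inserted attribute
        train_fraction: float = 0.8

    # Generated signature: __init__(self, batch_size=16, augment=True, train_fraction=0.8)
    broken = Child(16, 0.8)          # 0.8 silently lands in augment!
    safe = Child(batch_size=16, train_fraction=0.8)
    assert broken.augment == 0.8 and safe.augment is True

Keyword arguments, as used in the hunk above, stay correct even when
attributes are added or reordered.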
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index e45e5c8..705cfa3 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -34,6 +34,7 @@ SEED = 4711
 PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_lines"
 IMAGE_HEIGHT = 56
 IMAGE_WIDTH = 1024
+MAX_LABEL_LENGTH = 89


 @attr.s(auto_attribs=True)
@@ -42,11 +43,12 @@ class IAMLines(BaseDataModule):
     augment: bool = attr.ib(default=True)
     fraction: float = attr.ib(default=0.8)
+    dims: Tuple[int, int, int] = attr.ib(init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH))
+    output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1))

     def __attrs_post_init__(self) -> None:
+        # TODO: refactor this
         self.mapping, self.inverse_mapping, _ = emnist_mapping()
-        self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
-        self.output_dims = (89, 1)

     def prepare_data(self) -> None:
         """Creates the IAM lines dataset if not existing."""
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index bdfb490..9977978 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -41,6 +41,8 @@ class IAMParagraphs(BaseDataModule):
     augment: bool = attr.ib(default=True)
     train_fraction: float = attr.ib(default=0.8)
     word_pieces: bool = attr.ib(default=False)
+    dims: Tuple[int, int, int] = attr.ib(init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH))
+    output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1))

     def __attrs_post_init__(self) -> None:
         self.mapping, self.inverse_mapping, _ = emnist_mapping(
@@ -49,11 +51,6 @@ class IAMParagraphs(BaseDataModule):
         if self.word_pieces:
             self.mapping = WordPieceMapping()

-        self.train_fraction = train_fraction
-
-        self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
-        self.output_dims = (MAX_LABEL_LENGTH, 1)
-
     def prepare_data(self) -> None:
         """Create data for training/testing."""
         if PROCESSED_DATA_DIRNAME.exists():
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index 00fa2b6..a3697e7 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -2,6 +2,7 @@
 import random
 from typing import Any, List, Sequence, Tuple

+import attr
 from loguru import logger
 import numpy as np
 from PIL import Image
@@ -33,19 +34,10 @@ PROCESSED_DATA_DIRNAME = (
 )


+@attr.s(auto_attribs=True)
 class IAMSyntheticParagraphs(IAMParagraphs):
     """IAM Handwriting database of synthetic paragraphs."""

-    def __init__(
-        self,
-        batch_size: int = 16,
-        num_workers: int = 0,
-        train_fraction: float = 0.8,
-        augment: bool = True,
-        word_pieces: bool = False,
-    ) -> None:
-        super().__init__(batch_size, num_workers, train_fraction, augment, word_pieces)
-
     def prepare_data(self) -> None:
         """Prepare IAM lines to be used to generate paragraphs."""
         if PROCESSED_DATA_DIRNAME.exists():
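
Once the parent is an attrs class, a subclass that adds no attributes of its
own needs none of the deleted constructor boilerplate: it inherits the
generated __init__. A minimal sketch (names illustrative):

    import attr


    @attr.s(auto_attribs=True)
    class Parent:
        batch_size: int = 16
        augment: bool = True


    @attr.s(auto_attribs=True)
    class Child(Parent):
        """Inherits the generated __init__(batch_size=16, augment=True)."""


    child = Child(batch_size=32)
    assert child.batch_size == 32 and child.augment is True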
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index f95df0f..3b83056 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -3,20 +3,25 @@ from typing import Any, Dict, List, Tuple, Type
 import attr
 import hydra
-import loguru.logger as log
+from loguru import logger as log
 from omegaconf import DictConfig
-import pytorch_lightning as LightningModule
+from pytorch_lightning import LightningModule
 import torch
 from torch import nn
 from torch import Tensor
 import torchmetrics

+from text_recognizer.networks.base import BaseNetwork
+

 @attr.s
 class BaseLitModel(LightningModule):
     """Abstract PyTorch Lightning class."""

-    network: Type[nn.Module] = attr.ib()
+    def __attrs_pre_init__(self) -> None:
+        super().__init__()
+
+    network: Type[BaseNetwork] = attr.ib()
     criterion_config: DictConfig = attr.ib(converter=DictConfig)
     optimizer_config: DictConfig = attr.ib(converter=DictConfig)
     lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig)
@@ -24,23 +29,13 @@ class BaseLitModel(LightningModule):
     interval: str = attr.ib()
     monitor: str = attr.ib(default="val/loss")

-    loss_fn = attr.ib(init=False)
-
-    train_acc = attr.ib(init=False)
-    val_acc = attr.ib(init=False)
-    test_acc = attr.ib(init=False)
-
-    def __attrs_pre_init__(self) -> None:
-        super().__init__()
-
-    def __attrs_post_init__(self) -> None:
-        self.loss_fn = self._configure_criterion()
+    loss_fn: Type[nn.Module] = attr.ib(init=False)

-        # Accuracy metric
-        self.train_acc = torchmetrics.Accuracy()
-        self.val_acc = torchmetrics.Accuracy()
-        self.test_acc = torchmetrics.Accuracy()
+    train_acc: torchmetrics.Accuracy = attr.ib(init=False, default=torchmetrics.Accuracy())
+    val_acc: torchmetrics.Accuracy = attr.ib(init=False, default=torchmetrics.Accuracy())
+    test_acc: torchmetrics.Accuracy = attr.ib(init=False, default=torchmetrics.Accuracy())

+    @loss_fn.default
     def configure_criterion(self) -> Type[nn.Module]:
         """Returns a loss functions."""
         log.info(f"Instantiating criterion <{self.criterion_config._target_}>")
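
The @loss_fn.default decorator replaces the old __attrs_post_init__
assignment: attrs calls the decorated method during __init__ to compute the
attribute's default, with self and earlier attributes already set. A
stripped-down sketch (the real method instantiates the criterion from a hydra
config; a plain loss stands in here, and the class name is illustrative):

    import attr
    from torch import nn


    @attr.s
    class LitModelSketch:
        loss_fn: nn.Module = attr.ib(init=False)

        @loss_fn.default
        def _configure_criterion(self) -> nn.Module:
            # Runs at construction time to produce the default value.
            return nn.CrossEntropyLoss()


    assert isinstance(LitModelSketch().loss_fn, nn.CrossEntropyLoss)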
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 8c9fe8a..f5cb491 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,13 +1,11 @@
"""PyTorch Lightning model for base Transformers."""
-from typing import Dict, List, Optional, Union, Tuple, Type
+from typing import Dict, List, Optional, Sequence, Union, Tuple, Type
import attr
import hydra
from omegaconf import DictConfig
from torch import nn, Tensor
-from text_recognizer.data.emnist import emnist_mapping
-from text_recognizer.data.mappings import AbstractMapping
from text_recognizer.models.metrics import CharacterErrorRate
from text_recognizer.models.base import BaseLitModel
@@ -16,30 +14,18 @@ from text_recognizer.models.base import BaseLitModel
 class TransformerLitModel(BaseLitModel):
     """A PyTorch Lightning model for transformer networks."""

-    mapping_config: DictConfig = attr.ib(converter=DictConfig)
+    ignore_tokens: Sequence[str] = attr.ib(default=("<s>", "<e>", "<p>",))
+    val_cer: CharacterErrorRate = attr.ib(init=False)
+    test_cer: CharacterErrorRate = attr.ib(init=False)

     def __attrs_post_init__(self) -> None:
-        self.mapping, ignore_tokens = self._configure_mapping()
-        self.val_cer = CharacterErrorRate(ignore_tokens)
-        self.test_cer = CharacterErrorRate(ignore_tokens)
+        self.val_cer = CharacterErrorRate(self.ignore_tokens)
+        self.test_cer = CharacterErrorRate(self.ignore_tokens)

     def forward(self, data: Tensor) -> Tensor:
         """Forward pass with the transformer network."""
         return self.network.predict(data)

-    @staticmethod
-    def _configure_mapping() -> Tuple[Type[AbstractMapping], List[int]]:
-        """Configure mapping."""
-        # TODO: Fix me!!!
-        # Load config with hydra
-        mapping, inverse_mapping, _ = emnist_mapping(["\n"])
-        start_index = inverse_mapping["<s>"]
-        end_index = inverse_mapping["<e>"]
-        pad_index = inverse_mapping["<p>"]
-        ignore_tokens = [start_index, end_index, pad_index]
-        # TODO: add case for sentence pieces
-        return mapping, ignore_tokens
-
     def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
         """Training step."""
         data, targets = batch
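
One detail on the ignore_tokens default above: attrs evaluates default= once,
at class-definition time, so an immutable tuple is safe to share across
instances, while mutable defaults should go through a factory. Illustrative
sketch (names hypothetical):

    import attr


    @attr.s(auto_attribs=True)
    class Example:
        ignore_tokens: tuple = ("<s>", "<e>", "<p>")  # one shared, immutable object
        token_counts: dict = attr.ib(factory=dict)    # fresh dict per instance


    a, b = Example(), Example()
    assert a.ignore_tokens is b.ignore_tokens
    assert a.token_counts is not b.token_counts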
diff --git a/text_recognizer/networks/base.py b/text_recognizer/networks/base.py
new file mode 100644
index 0000000..07b6a32
--- /dev/null
+++ b/text_recognizer/networks/base.py
@@ -0,0 +1,18 @@
+"""Base network with required methods."""
+from abc import abstractmethod
+
+import attr
+from torch import nn, Tensor
+
+
+@attr.s
+class BaseNetwork(nn.Module):
+ """Base network."""
+
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
+
+ @abstractmethod
+ def predict(self, x: Tensor) -> Tensor:
+ """Return token indices for predictions."""
+ ...
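
The new base class exists because nn.Module must run its own __init__ before
any attribute (in particular a submodule) is assigned, and an attrs-generated
__init__ does not call super().__init__() by itself; __attrs_pre_init__ closes
that gap and is inherited by subclasses. A hypothetical subclass following the
pattern (TinyNetwork and its layer sizes are illustrative):

    import attr
    import torch
    from torch import nn, Tensor

    from text_recognizer.networks.base import BaseNetwork


    @attr.s(auto_attribs=True)
    class TinyNetwork(BaseNetwork):
        num_classes: int = attr.ib()
        head: nn.Linear = attr.ib(init=False)

        def __attrs_post_init__(self) -> None:
            # Safe: nn.Module.__init__ already ran in the inherited
            # __attrs_pre_init__, so submodule registration works.
            self.head = nn.Linear(32, self.num_classes)

        def predict(self, x: Tensor) -> Tensor:
            return self.head(x).argmax(dim=-1)


    tokens = TinyNetwork(num_classes=10).predict(torch.randn(2, 32))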
diff --git a/text_recognizer/networks/cnn_tranformer.py b/text_recognizer/networks/conv_transformer.py
index ce7ec43..4acdc36 100644
--- a/text_recognizer/networks/cnn_tranformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -7,6 +7,7 @@ import torch
 from torch import nn, Tensor

 from text_recognizer.data.mappings import AbstractMapping
+from text_recognizer.networks.base import BaseNetwork
 from text_recognizer.networks.encoders.efficientnet import EfficientNet
 from text_recognizer.networks.transformer.layers import Decoder
 from text_recognizer.networks.transformer.positional_encodings import (
@@ -15,39 +16,37 @@ from text_recognizer.networks.transformer.positional_encodings import (
 )


-@attr.s
-class Reader(nn.Module):
-    def __attrs_pre_init__(self) -> None:
-        super().__init__()
-
+@attr.s(auto_attribs=True)
+class ConvTransformer(BaseNetwork):
     # Parameters and placeholders,
     input_dims: Tuple[int, int, int] = attr.ib()
     hidden_dim: int = attr.ib()
     dropout_rate: float = attr.ib()
     max_output_len: int = attr.ib()
     num_classes: int = attr.ib()
-    padding_idx: int = attr.ib()
     start_token: str = attr.ib()
-    start_index: int = attr.ib(init=False)
+    start_index: Tensor = attr.ib(init=False)
     end_token: str = attr.ib()
-    end_index: int = attr.ib(init=False)
+    end_index: Tensor = attr.ib(init=False)
     pad_token: str = attr.ib()
-    pad_index: int = attr.ib(init=False)
+    pad_index: Tensor = attr.ib(init=False)

     # Modules.
     encoder: EfficientNet = attr.ib()
     decoder: Decoder = attr.ib()
+    mapping: Type[AbstractMapping] = attr.ib()
+
     latent_encoder: nn.Sequential = attr.ib(init=False)
     token_embedding: nn.Embedding = attr.ib(init=False)
     token_pos_encoder: PositionalEncoding = attr.ib(init=False)
     head: nn.Linear = attr.ib(init=False)
-    mapping: Type[AbstractMapping] = attr.ib(init=False)

     def __attrs_post_init__(self) -> None:
         """Post init configuration."""
-        self.start_index = int(self.mapping.get_index(self.start_token))
-        self.end_index = int(self.mapping.get_index(self.end_token))
-        self.pad_index = int(self.mapping.get_index(self.pad_token))
+        self.start_index = self.mapping.get_index(self.start_token)
+        self.end_index = self.mapping.get_index(self.end_token)
+        self.pad_index = self.mapping.get_index(self.pad_token)
+
         # Latent projector for down sampling number of filters and 2d
         # positional encoding.
         self.latent_encoder = nn.Sequential(
@@ -130,7 +129,7 @@ class Reader(nn.Module):
         Returns:
             Tensor: Sequence of word piece embeddings.
         """
-        context_mask = context != self.padding_idx
+        context_mask = context != self.pad_index
         context = self.token_embedding(context) * math.sqrt(self.hidden_dim)
         context = self.token_pos_encoder(context)
         out = self.decoder(x=context, context=z, mask=context_mask)
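
For context on this last hunk: the decode step masks padding positions, scales
token embeddings by sqrt(hidden_dim) as in the original Transformer, and adds
positional encodings before the decoder attends over the latent z. A
self-contained sketch of just those lines (shapes and values illustrative):

    import math

    import torch
    from torch import nn

    hidden_dim, num_classes, pad_index = 64, 100, 3
    token_embedding = nn.Embedding(num_classes, hidden_dim)

    context = torch.tensor([[5, 7, 3, 3]])   # (batch, seq); 3 stands in for <p>
    context_mask = context != pad_index      # False at padded positions
    emb = token_embedding(context) * math.sqrt(hidden_dim)
    # ...token_pos_encoder(emb) and decoder(x=emb, context=z, mask=context_mask)
    # follow in the real network.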