summaryrefslogtreecommitdiff
path: root/text_recognizer
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer')
-rw-r--r--text_recognizer/data/base_data_module.py14
-rw-r--r--text_recognizer/data/base_dataset.py11
-rw-r--r--text_recognizer/data/download_utils.py8
-rw-r--r--text_recognizer/data/emnist.py21
-rw-r--r--text_recognizer/data/emnist_lines.py21
-rw-r--r--text_recognizer/data/iam.py4
-rw-r--r--text_recognizer/data/iam_extended_paragraphs.py6
-rw-r--r--text_recognizer/data/iam_lines.py21
-rw-r--r--text_recognizer/data/iam_paragraphs.py18
-rw-r--r--text_recognizer/data/iam_preprocessor.py16
-rw-r--r--text_recognizer/data/iam_synthetic_paragraphs.py19
-rw-r--r--text_recognizer/data/make_wordpieces.py8
-rw-r--r--text_recognizer/data/mappings.py24
-rw-r--r--text_recognizer/models/base.py2
-rw-r--r--text_recognizer/models/metrics.py4
-rw-r--r--text_recognizer/models/transformer.py16
-rw-r--r--text_recognizer/models/vqvae.py2
-rw-r--r--text_recognizer/networks/conv_transformer.py3
-rw-r--r--text_recognizer/networks/encoders/efficientnet/efficientnet.py15
-rw-r--r--text_recognizer/networks/encoders/efficientnet/mbconv.py139
-rw-r--r--text_recognizer/networks/transformer/attention.py7
-rw-r--r--text_recognizer/networks/transformer/layers.py6
22 files changed, 180 insertions, 205 deletions
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index 408ae36..fd914b6 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -1,11 +1,12 @@
"""Base lightning DataModule class."""
from pathlib import Path
-from typing import Any, Dict, Tuple
+from typing import Dict, Tuple
import attr
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
+from text_recognizer.data.mappings import AbstractMapping
from text_recognizer.data.base_dataset import BaseDataset
@@ -24,8 +25,10 @@ class BaseDataModule(LightningDataModule):
def __attrs_pre_init__(self) -> None:
super().__init__()
+ mapping: AbstractMapping = attr.ib()
batch_size: int = attr.ib(default=16)
num_workers: int = attr.ib(default=0)
+ pin_memory: bool = attr.ib(default=True)
# Placeholders
data_train: BaseDataset = attr.ib(init=False, default=None)
@@ -33,8 +36,6 @@ class BaseDataModule(LightningDataModule):
data_test: BaseDataset = attr.ib(init=False, default=None)
dims: Tuple[int, ...] = attr.ib(init=False, default=None)
output_dims: Tuple[int, ...] = attr.ib(init=False, default=None)
- mapping: Any = attr.ib(init=False, default=None)
- inverse_mapping: Dict[str, int] = attr.ib(init=False)
@classmethod
def data_dirname(cls) -> Path:
@@ -46,7 +47,6 @@ class BaseDataModule(LightningDataModule):
return {
"input_dim": self.dims,
"output_dims": self.output_dims,
- "mapping": self.mapping,
}
def prepare_data(self) -> None:
@@ -72,7 +72,7 @@ class BaseDataModule(LightningDataModule):
shuffle=True,
batch_size=self.batch_size,
num_workers=self.num_workers,
- pin_memory=True,
+ pin_memory=self.pin_memory,
)
def val_dataloader(self) -> DataLoader:
@@ -82,7 +82,7 @@ class BaseDataModule(LightningDataModule):
shuffle=False,
batch_size=self.batch_size,
num_workers=self.num_workers,
- pin_memory=True,
+ pin_memory=self.pin_memory,
)
def test_dataloader(self) -> DataLoader:
@@ -92,5 +92,5 @@ class BaseDataModule(LightningDataModule):
shuffle=False,
batch_size=self.batch_size,
num_workers=self.num_workers,
- pin_memory=True,
+ pin_memory=self.pin_memory,
)
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py
index c26f1c9..8640d92 100644
--- a/text_recognizer/data/base_dataset.py
+++ b/text_recognizer/data/base_dataset.py
@@ -1,5 +1,5 @@
"""Base PyTorch Dataset class."""
-from typing import Any, Callable, Dict, Sequence, Tuple, Union
+from typing import Callable, Dict, Optional, Sequence, Tuple, Union
import attr
import torch
@@ -22,14 +22,13 @@ class BaseDataset(Dataset):
data: Union[Sequence, Tensor] = attr.ib()
targets: Union[Sequence, Tensor] = attr.ib()
- transform: Callable = attr.ib()
- target_transform: Callable = attr.ib()
+ transform: Optional[Callable] = attr.ib(default=None)
+ target_transform: Optional[Callable] = attr.ib(default=None)
def __attrs_pre_init__(self) -> None:
super().__init__()
def __attrs_post_init__(self) -> None:
- # TODO: refactor this
if len(self.data) != len(self.targets):
raise ValueError("Data and targets must be of equal length.")
@@ -37,14 +36,14 @@ class BaseDataset(Dataset):
"""Return the length of the dataset."""
return len(self.data)
- def __getitem__(self, index: int) -> Tuple[Any, Any]:
+ def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
"""Return a datum and its target, after processing by transforms.
Args:
index (int): Index of a datum in the dataset.
Returns:
- Tuple[Any, Any]: Datum and target pair.
+ Tuple[Tensor, Tensor]: Datum and target pair.
"""
datum, target = self.data[index], self.targets[index]
diff --git a/text_recognizer/data/download_utils.py b/text_recognizer/data/download_utils.py
index e3dc68c..8938830 100644
--- a/text_recognizer/data/download_utils.py
+++ b/text_recognizer/data/download_utils.py
@@ -4,7 +4,7 @@ from pathlib import Path
from typing import Dict, List, Optional
from urllib.request import urlretrieve
-from loguru import logger
+from loguru import logger as log
from tqdm import tqdm
@@ -32,7 +32,7 @@ class TqdmUpTo(tqdm):
total_size (Optional[int]): Total size in tqdm units. Defaults to None.
"""
if total_size is not None:
- self.total = total_size # pylint: disable=attribute-defined-outside-init
+ self.total = total_size
self.update(blocks * block_size - self.n)
@@ -62,9 +62,9 @@ def download_dataset(metadata: Dict, dl_dir: Path) -> Optional[Path]:
filename = dl_dir / metadata["filename"]
if filename.exists():
return
- logger.info(f"Downloading raw dataset from {metadata['url']} to {filename}...")
+ log.info(f"Downloading raw dataset from {metadata['url']} to {filename}...")
_download_url(metadata["url"], filename)
- logger.info("Computing the SHA-256...")
+ log.info("Computing the SHA-256...")
sha256 = _compute_sha256(filename)
if sha256 != metadata["sha256"]:
raise ValueError(
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index 2d0ac29..c6be123 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -3,12 +3,12 @@ import json
import os
from pathlib import Path
import shutil
-from typing import Callable, Dict, List, Optional, Sequence, Tuple
+from typing import Callable, Dict, List, Optional, Set, Sequence, Tuple
import zipfile
import attr
import h5py
-from loguru import logger
+from loguru import logger as log
import numpy as np
import toml
import torchvision.transforms as T
@@ -50,8 +50,7 @@ class EMNIST(BaseDataModule):
transform: Callable = attr.ib(init=False, default=T.Compose([T.ToTensor()]))
def __attrs_post_init__(self) -> None:
- self.mapping, self.inverse_mapping, input_shape = emnist_mapping()
- self.dims = (1, *input_shape)
+ self.dims = (1, *self.mapping.input_size)
def prepare_data(self) -> None:
"""Downloads dataset if not present."""
@@ -106,7 +105,7 @@ class EMNIST(BaseDataModule):
def emnist_mapping(
- extra_symbols: Optional[Sequence[str]] = None,
+ extra_symbols: Optional[Set[str]] = None,
) -> Tuple[List, Dict[str, int], List[int]]:
"""Return the EMNIST mapping."""
if not ESSENTIALS_FILENAME.exists():
@@ -130,7 +129,7 @@ def download_and_process_emnist() -> None:
def _process_raw_dataset(filename: str, dirname: Path) -> None:
"""Processes the raw EMNIST dataset."""
- logger.info("Unzipping EMNIST...")
+ log.info("Unzipping EMNIST...")
curdir = os.getcwd()
os.chdir(dirname)
content = zipfile.ZipFile(filename, "r")
@@ -138,7 +137,7 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
from scipy.io import loadmat
- logger.info("Loading training data from .mat file")
+ log.info("Loading training data from .mat file")
data = loadmat("matlab/emnist-byclass.mat")
x_train = (
data["dataset"]["train"][0, 0]["images"][0, 0]
@@ -152,11 +151,11 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
if SAMPLE_TO_BALANCE:
- logger.info("Balancing classes to reduce amount of data")
+ log.info("Balancing classes to reduce amount of data")
x_train, y_train = _sample_to_balance(x_train, y_train)
x_test, y_test = _sample_to_balance(x_test, y_test)
- logger.info("Saving to HDF5 in a compressed format...")
+ log.info("Saving to HDF5 in a compressed format...")
PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
with h5py.File(PROCESSED_DATA_FILENAME, "w") as f:
f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf")
@@ -164,7 +163,7 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf")
f.create_dataset("y_test", data=y_test, dtype="u1", compression="lzf")
- logger.info("Saving essential dataset parameters to text_recognizer/datasets...")
+ log.info("Saving essential dataset parameters to text_recognizer/datasets...")
mapping = {int(k): chr(v) for k, v in data["dataset"]["mapping"][0, 0]}
characters = _augment_emnist_characters(mapping.values())
essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])}
@@ -172,7 +171,7 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
with ESSENTIALS_FILENAME.open(mode="w") as f:
json.dump(essentials, f)
- logger.info("Cleaning up...")
+ log.info("Cleaning up...")
shutil.rmtree("matlab")
os.chdir(curdir)
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 7548ad5..5298726 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -1,11 +1,11 @@
"""Dataset of generated text from EMNIST characters."""
from collections import defaultdict
from pathlib import Path
-from typing import Callable, Dict, Tuple
+from typing import Callable, List, Tuple
import attr
import h5py
-from loguru import logger
+from loguru import logger as log
import numpy as np
import torch
from torchvision import transforms
@@ -46,8 +46,7 @@ class EMNISTLines(BaseDataModule):
emnist: EMNIST = attr.ib(init=False, default=None)
def __attrs_post_init__(self) -> None:
- self.emnist = EMNIST()
- self.mapping = self.emnist.mapping
+ self.emnist = EMNIST(mapping=self.mapping)
max_width = (
int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap))
@@ -86,7 +85,7 @@ class EMNISTLines(BaseDataModule):
self._generate_data("test")
def setup(self, stage: str = None) -> None:
- logger.info("EMNISTLinesDataset loading data from HDF5...")
+ log.info("EMNISTLinesDataset loading data from HDF5...")
if stage == "fit" or stage is None:
print(self.data_filename)
with h5py.File(self.data_filename, "r") as f:
@@ -137,7 +136,7 @@ class EMNISTLines(BaseDataModule):
return basic + data
def _generate_data(self, split: str) -> None:
- logger.info(f"EMNISTLines generating data for {split}...")
+ log.info(f"EMNISTLines generating data for {split}...")
sentence_generator = SentenceGenerator(
self.max_length - 2
) # Subtract by 2 because start/end token
@@ -148,17 +147,17 @@ class EMNISTLines(BaseDataModule):
if split == "train":
samples_by_char = _get_samples_by_char(
- emnist.x_train, emnist.y_train, emnist.mapping
+ emnist.x_train, emnist.y_train, self.mapping.mapping
)
num = self.num_train
elif split == "val":
samples_by_char = _get_samples_by_char(
- emnist.x_train, emnist.y_train, emnist.mapping
+ emnist.x_train, emnist.y_train, self.mapping.mapping
)
num = self.num_val
else:
samples_by_char = _get_samples_by_char(
- emnist.x_test, emnist.y_test, emnist.mapping
+ emnist.x_test, emnist.y_test, self.mapping.mapping
)
num = self.num_test
@@ -173,14 +172,14 @@ class EMNISTLines(BaseDataModule):
self.dims,
)
y = convert_strings_to_labels(
- y, emnist.inverse_mapping, length=MAX_OUTPUT_LENGTH
+ y, self.mapping.inverse_mapping, length=MAX_OUTPUT_LENGTH
)
f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf")
f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf")
def _get_samples_by_char(
- samples: np.ndarray, labels: np.ndarray, mapping: Dict
+ samples: np.ndarray, labels: np.ndarray, mapping: List
) -> defaultdict:
samples_by_char = defaultdict(list)
for sample, label in zip(samples, labels):
diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py
index 3982c4f..7278eb2 100644
--- a/text_recognizer/data/iam.py
+++ b/text_recognizer/data/iam.py
@@ -7,7 +7,7 @@ import zipfile
import attr
from boltons.cacheutils import cachedproperty
-from loguru import logger
+from loguru import logger as log
import toml
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
@@ -92,7 +92,7 @@ class IAM(BaseDataModule):
def _extract_raw_dataset(filename: Path, dirname: Path) -> None:
- logger.info("Extracting IAM data...")
+ log.info("Extracting IAM data...")
curdir = os.getcwd()
os.chdir(dirname)
with zipfile.ZipFile(filename, "r") as f:
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 0e97801..ccf0759 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -4,7 +4,6 @@ from typing import Dict, List
import attr
from torch.utils.data import ConcatDataset
-from text_recognizer.data.base_dataset import BaseDataset
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.iam_paragraphs import IAMParagraphs
from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs
@@ -20,6 +19,7 @@ class IAMExtendedParagraphs(BaseDataModule):
def __attrs_post_init__(self) -> None:
self.iam_paragraphs = IAMParagraphs(
+ mapping=self.mapping,
batch_size=self.batch_size,
num_workers=self.num_workers,
train_fraction=self.train_fraction,
@@ -27,6 +27,7 @@ class IAMExtendedParagraphs(BaseDataModule):
word_pieces=self.word_pieces,
)
self.iam_synthetic_paragraphs = IAMSyntheticParagraphs(
+ mapping=self.mapping,
batch_size=self.batch_size,
num_workers=self.num_workers,
train_fraction=self.train_fraction,
@@ -36,7 +37,6 @@ class IAMExtendedParagraphs(BaseDataModule):
self.dims = self.iam_paragraphs.dims
self.output_dims = self.iam_paragraphs.output_dims
- self.num_classes = self.iam_paragraphs.num_classes
def prepare_data(self) -> None:
"""Prepares the paragraphs data."""
@@ -58,7 +58,7 @@ class IAMExtendedParagraphs(BaseDataModule):
"""Returns info about the dataset."""
basic = (
"IAM Original and Synthetic Paragraphs Dataset\n" # pylint: disable=no-member
- f"Num classes: {len(self.num_classes)}\n"
+ f"Num classes: {len(self.mapping)}\n"
f"Dims: {self.dims}\n"
f"Output dims: {self.output_dims}\n"
)
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index b7f3fdd..1c63729 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -2,15 +2,14 @@
If not created, will generate a handwritten lines dataset from the IAM paragraphs
dataset.
-
"""
import json
from pathlib import Path
import random
-from typing import Dict, List, Sequence, Tuple
+from typing import List, Sequence, Tuple
import attr
-from loguru import logger
+from loguru import logger as log
from PIL import Image, ImageFile, ImageOps
import numpy as np
from torch import Tensor
@@ -23,7 +22,7 @@ from text_recognizer.data.base_dataset import (
split_dataset,
)
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
-from text_recognizer.data.emnist import emnist_mapping
+from text_recognizer.data.mappings import EmnistMapping
from text_recognizer.data.iam import IAM
from text_recognizer.data import image_utils
@@ -48,17 +47,13 @@ class IAMLines(BaseDataModule):
)
output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1))
- def __attrs_post_init__(self) -> None:
- # TODO: refactor this
- self.mapping, self.inverse_mapping, _ = emnist_mapping()
-
def prepare_data(self) -> None:
"""Creates the IAM lines dataset if not existing."""
if PROCESSED_DATA_DIRNAME.exists():
return
- logger.info("Cropping IAM lines regions...")
- iam = IAM()
+ log.info("Cropping IAM lines regions...")
+ iam = IAM(mapping=EmnistMapping())
iam.prepare_data()
crops_train, labels_train = line_crops_and_labels(iam, "train")
crops_test, labels_test = line_crops_and_labels(iam, "test")
@@ -66,7 +61,7 @@ class IAMLines(BaseDataModule):
shapes = np.array([crop.size for crop in crops_train + crops_test])
aspect_ratios = shapes[:, 0] / shapes[:, 1]
- logger.info("Saving images, labels, and statistics...")
+ log.info("Saving images, labels, and statistics...")
save_images_and_labels(
crops_train, labels_train, "train", PROCESSED_DATA_DIRNAME
)
@@ -91,7 +86,7 @@ class IAMLines(BaseDataModule):
raise ValueError("Target length longer than max output length.")
y_train = convert_strings_to_labels(
- labels_train, self.inverse_mapping, length=self.output_dims[0]
+ labels_train, self.mapping.inverse_mapping, length=self.output_dims[0]
)
data_train = BaseDataset(
x_train, y_train, transform=get_transform(IMAGE_WIDTH, self.augment)
@@ -110,7 +105,7 @@ class IAMLines(BaseDataModule):
raise ValueError("Taget length longer than max output length.")
y_test = convert_strings_to_labels(
- labels_test, self.inverse_mapping, length=self.output_dims[0]
+ labels_test, self.mapping.inverse_mapping, length=self.output_dims[0]
)
self.data_test = BaseDataset(
x_test, y_test, transform=get_transform(IMAGE_WIDTH)
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 0f3a2ce..6189f7d 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -4,7 +4,7 @@ from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple
import attr
-from loguru import logger
+from loguru import logger as log
import numpy as np
from PIL import Image, ImageOps
import torchvision.transforms as T
@@ -17,9 +17,8 @@ from text_recognizer.data.base_dataset import (
split_dataset,
)
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
-from text_recognizer.data.emnist import emnist_mapping
+from text_recognizer.data.mappings import EmnistMapping
from text_recognizer.data.iam import IAM
-from text_recognizer.data.mappings import WordPieceMapping
from text_recognizer.data.transforms import WordPiece
@@ -38,7 +37,6 @@ MAX_LABEL_LENGTH = 682
class IAMParagraphs(BaseDataModule):
"""IAM handwriting database paragraphs."""
- num_classes: int = attr.ib()
word_pieces: bool = attr.ib(default=False)
augment: bool = attr.ib(default=True)
train_fraction: float = attr.ib(default=0.8)
@@ -46,21 +44,17 @@ class IAMParagraphs(BaseDataModule):
init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH)
)
output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1))
- inverse_mapping: Dict[str, int] = attr.ib(init=False)
-
- def __attrs_post_init__(self) -> None:
- _, self.inverse_mapping, _ = emnist_mapping(extra_symbols=[NEW_LINE_TOKEN])
def prepare_data(self) -> None:
"""Create data for training/testing."""
if PROCESSED_DATA_DIRNAME.exists():
return
- logger.info(
+ log.info(
"Cropping IAM paragraph regions and saving them along with labels..."
)
- iam = IAM()
+ iam = IAM(mapping=EmnistMapping())
iam.prepare_data()
properties = {}
@@ -89,7 +83,7 @@ class IAMParagraphs(BaseDataModule):
crops, labels = _load_processed_crops_and_labels(split)
data = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops]
targets = convert_strings_to_labels(
- strings=labels, mapping=self.inverse_mapping, length=self.output_dims[0]
+ strings=labels, mapping=self.mapping.inverse_mapping, length=self.output_dims[0]
)
return BaseDataset(
data,
@@ -98,7 +92,7 @@ class IAMParagraphs(BaseDataModule):
target_transform=get_target_transform(self.word_pieces),
)
- logger.info(f"Loading IAM paragraph regions and lines for {stage}...")
+ log.info(f"Loading IAM paragraph regions and lines for {stage}...")
_validate_data_dims(input_dims=self.dims, output_dims=self.output_dims)
if stage == "fit" or stage is None:
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py
index 93a13bb..bcd77b4 100644
--- a/text_recognizer/data/iam_preprocessor.py
+++ b/text_recognizer/data/iam_preprocessor.py
@@ -1,18 +1,16 @@
"""Preprocessor for extracting word letters from the IAM dataset.
The code is mostly stolen from:
-
https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py
-
"""
import collections
import itertools
from pathlib import Path
import re
-from typing import List, Optional, Union, Sequence
+from typing import List, Optional, Union, Set
import click
-from loguru import logger
+from loguru import logger as log
import torch
@@ -57,7 +55,7 @@ class Preprocessor:
lexicon_path: Optional[Union[str, Path]] = None,
use_words: bool = False,
prepend_wordsep: bool = False,
- special_tokens: Optional[Sequence[str]] = None,
+ special_tokens: Optional[Set[str]] = None,
) -> None:
self.wordsep = "▁"
self._use_word = use_words
@@ -186,7 +184,7 @@ def cli(
/ "iam"
/ "iamdb"
)
- logger.debug(f"Using data dir: {data_dir}")
+ log.debug(f"Using data dir: {data_dir}")
if not data_dir.exists():
raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
else:
@@ -196,15 +194,15 @@ def cli(
preprocessor.extract_train_text()
processed_dir = data_dir.parents[2] / "processed" / "iam_lines"
- logger.debug(f"Saving processed files at: {processed_dir}")
+ log.debug(f"Saving processed files at: {processed_dir}")
if save_text is not None:
- logger.info("Saving training text")
+ log.info("Saving training text")
with open(processed_dir / save_text, "w") as f:
f.write("\n".join(t for t in preprocessor.text))
if save_tokens is not None:
- logger.info("Saving tokens")
+ log.info("Saving tokens")
with open(processed_dir / save_tokens, "w") as f:
f.write("\n".join(preprocessor.tokens))
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index f00a494..c938f8b 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -3,7 +3,7 @@ import random
from typing import Any, List, Sequence, Tuple
import attr
-from loguru import logger
+from loguru import logger as log
import numpy as np
from PIL import Image
@@ -21,6 +21,7 @@ from text_recognizer.data.iam_paragraphs import (
IMAGE_SCALE_FACTOR,
resize_image,
)
+from text_recognizer.data.mappings import EmnistMapping
from text_recognizer.data.iam import IAM
from text_recognizer.data.iam_lines import (
line_crops_and_labels,
@@ -43,10 +44,10 @@ class IAMSyntheticParagraphs(IAMParagraphs):
if PROCESSED_DATA_DIRNAME.exists():
return
- logger.info("Preparing IAM lines for synthetic paragraphs dataset.")
- logger.info("Cropping IAM line regions and loading labels.")
+ log.info("Preparing IAM lines for synthetic paragraphs dataset.")
+ log.info("Cropping IAM line regions and loading labels.")
- iam = IAM()
+ iam = IAM(mapping=EmnistMapping())
iam.prepare_data()
crops_train, labels_train = line_crops_and_labels(iam, "train")
@@ -55,7 +56,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
crops_train = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops_train]
crops_test = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops_test]
- logger.info(f"Saving images and labels at {PROCESSED_DATA_DIRNAME}")
+ log.info(f"Saving images and labels at {PROCESSED_DATA_DIRNAME}")
save_images_and_labels(
crops_train, labels_train, "train", PROCESSED_DATA_DIRNAME
)
@@ -64,7 +65,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
def setup(self, stage: str = None) -> None:
"""Loading synthetic dataset."""
- logger.info(f"IAM Synthetic dataset steup for stage {stage}...")
+ log.info(f"IAM Synthetic dataset steup for stage {stage}...")
if stage == "fit" or stage is None:
line_crops, line_labels = load_line_crops_and_labels(
@@ -76,7 +77,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
targets = convert_strings_to_labels(
strings=paragraphs_labels,
- mapping=self.inverse_mapping,
+ mapping=self.mapping.inverse_mapping,
length=self.output_dims[0],
)
self.data_train = BaseDataset(
@@ -144,7 +145,7 @@ def generate_synthetic_paragraphs(
[line_labels[i] for i in paragraph_indices]
)
if len(paragraph_label) > paragraphs_properties["label_length"]["max"]:
- logger.info(
+ log.info(
"Label longer than longest label in original IAM paragraph dataset - hence dropping."
)
continue
@@ -158,7 +159,7 @@ def generate_synthetic_paragraphs(
paragraph_crop.height > max_paragraph_shape[0]
or paragraph_crop.width > max_paragraph_shape[1]
):
- logger.info(
+ log.info(
"Crop larger than largest crop in original IAM paragraphs dataset - hence dropping"
)
continue
diff --git a/text_recognizer/data/make_wordpieces.py b/text_recognizer/data/make_wordpieces.py
index ef9eb1b..40fbee4 100644
--- a/text_recognizer/data/make_wordpieces.py
+++ b/text_recognizer/data/make_wordpieces.py
@@ -10,7 +10,7 @@ from pathlib import Path
from typing import List, Optional, Union
import click
-from loguru import logger
+from loguru import logger as log
import sentencepiece as spm
from text_recognizer.data.iam_preprocessor import load_metadata
@@ -63,9 +63,9 @@ def save_pieces(
vocab: set,
) -> None:
"""Saves word pieces to disk."""
- logger.info(f"Generating word piece list of size {num_pieces}.")
+ log.info(f"Generating word piece list of size {num_pieces}.")
pieces = [sp.id_to_piece(i) for i in range(1, num_pieces + 1)]
- logger.info(f"Encoding vocabulary of size {len(vocab)}.")
+ log.info(f"Encoding vocabulary of size {len(vocab)}.")
encoded_vocab = [sp.encode_as_pieces(v) for v in vocab]
# Save pieces to file.
@@ -101,7 +101,7 @@ def cli(
data_dir = (
Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines"
)
- logger.debug(f"Using data dir: {data_dir}")
+ log.debug(f"Using data dir: {data_dir}")
if not data_dir.exists():
raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
else:
diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py
index b69e888..d1c64dd 100644
--- a/text_recognizer/data/mappings.py
+++ b/text_recognizer/data/mappings.py
@@ -1,18 +1,30 @@
"""Mapping to and from word pieces."""
from abc import ABC, abstractmethod
from pathlib import Path
-from typing import Dict, List, Optional, Union, Set, Sequence
+from typing import Dict, List, Optional, Union, Set
import attr
-import loguru.logger as log
import torch
+from loguru import logger as log
from torch import Tensor
from text_recognizer.data.emnist import emnist_mapping
from text_recognizer.data.iam_preprocessor import Preprocessor
+@attr.s
class AbstractMapping(ABC):
+ input_size: List[int] = attr.ib(init=False)
+ mapping: List[str] = attr.ib(init=False)
+ inverse_mapping: Dict[str, int] = attr.ib(init=False)
+
+ def __len__(self) -> int:
+ return len(self.mapping)
+
+ @property
+ def num_classes(self) -> int:
+ return self.__len__()
+
@abstractmethod
def get_token(self, *args, **kwargs) -> str:
...
@@ -30,15 +42,13 @@ class AbstractMapping(ABC):
...
-@attr.s
+@attr.s(auto_attribs=True)
class EmnistMapping(AbstractMapping):
- extra_symbols: Optional[Set[str]] = attr.ib(default=None, converter=set)
- mapping: Sequence[str] = attr.ib(init=False)
- inverse_mapping: Dict[str, int] = attr.ib(init=False)
- input_size: List[int] = attr.ib(init=False)
+ extra_symbols: Optional[Set[str]] = attr.ib(default=None)
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
+ self.extra_symbols = set(self.extra_symbols) if self.extra_symbols is not None else None
self.mapping, self.inverse_mapping, self.input_size = emnist_mapping(
self.extra_symbols
)
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index caf63c1..8ce5c37 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -12,7 +12,7 @@ from torch import Tensor
import torchmetrics
-@attr.s
+@attr.s(eq=False)
class BaseLitModel(LightningModule):
"""Abstract PyTorch Lightning class."""
diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py
index 0eb42dc..f83c9e4 100644
--- a/text_recognizer/models/metrics.py
+++ b/text_recognizer/models/metrics.py
@@ -8,11 +8,11 @@ from torch import Tensor
from torchmetrics import Metric
-@attr.s
+@attr.s(eq=False)
class CharacterErrorRate(Metric):
"""Character error rate metric, computed using Levenshtein distance."""
- ignore_indices: Set = attr.ib(converter=set)
+ ignore_indices: Set[Tensor] = attr.ib(converter=set)
error: Tensor = attr.ib(init=False)
total: Tensor = attr.ib(init=False)
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 0e01bb5..91e088d 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,5 +1,5 @@
"""PyTorch Lightning model for base Transformers."""
-from typing import Sequence, Tuple, Type
+from typing import Tuple, Type, Set
import attr
import torch
@@ -10,20 +10,20 @@ from text_recognizer.models.metrics import CharacterErrorRate
from text_recognizer.models.base import BaseLitModel
-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, eq=False)
class TransformerLitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
- mapping: Type[AbstractMapping] = attr.ib()
- start_token: str = attr.ib()
- end_token: str = attr.ib()
- pad_token: str = attr.ib()
+ mapping: Type[AbstractMapping] = attr.ib(default=None)
+ start_token: str = attr.ib(default="<s>")
+ end_token: str = attr.ib(default="<e>")
+ pad_token: str = attr.ib(default="<p>")
start_index: Tensor = attr.ib(init=False)
end_index: Tensor = attr.ib(init=False)
pad_index: Tensor = attr.ib(init=False)
- ignore_indices: Sequence[str] = attr.ib(init=False)
+ ignore_indices: Set[Tensor] = attr.ib(init=False)
val_cer: CharacterErrorRate = attr.ib(init=False)
test_cer: CharacterErrorRate = attr.ib(init=False)
@@ -32,7 +32,7 @@ class TransformerLitModel(BaseLitModel):
self.start_index = self.mapping.get_index(self.start_token)
self.end_index = self.mapping.get_index(self.end_token)
self.pad_index = self.mapping.get_index(self.pad_token)
- self.ignore_indices = [self.start_index, self.end_index, self.pad_index]
+ self.ignore_indices = set([self.start_index, self.end_index, self.pad_index])
self.val_cer = CharacterErrorRate(self.ignore_indices)
self.test_cer = CharacterErrorRate(self.ignore_indices)
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
index e215e14..22da018 100644
--- a/text_recognizer/models/vqvae.py
+++ b/text_recognizer/models/vqvae.py
@@ -10,7 +10,7 @@ import wandb
from text_recognizer.models.base import BaseLitModel
-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, eq=False)
class VQVAELitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index 7371be4..09cc654 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -13,7 +13,7 @@ from text_recognizer.networks.transformer.positional_encodings import (
)
-@attr.s
+@attr.s(eq=False)
class ConvTransformer(nn.Module):
"""Convolutional encoder and transformer decoder network."""
@@ -121,6 +121,7 @@ class ConvTransformer(nn.Module):
Returns:
Tensor: Sequence of word piece embeddings.
"""
+ context = context.long()
context_mask = context != self.pad_index
context = self.token_embedding(context) * math.sqrt(self.hidden_dim)
context = self.token_pos_encoder(context)
diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
index a36150a..b8eb53b 100644
--- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
@@ -1,4 +1,4 @@
-"""Efficient net."""
+"""Efficientnet backbone."""
from typing import Tuple
import attr
@@ -12,8 +12,10 @@ from .utils import (
)
-@attr.s
+@attr.s(eq=False)
class EfficientNet(nn.Module):
+ """Efficientnet without classification head."""
+
def __attrs_pre_init__(self) -> None:
super().__init__()
@@ -47,11 +49,13 @@ class EfficientNet(nn.Module):
@arch.validator
def check_arch(self, attribute: attr._make.Attribute, value: str) -> None:
+ """Validates the efficientnet architecure."""
if value not in self.archs:
raise ValueError(f"{value} not a valid architecure.")
self.params = self.archs[value]
def _build(self) -> None:
+ """Builds the efficientnet backbone."""
_block_args = block_args()
in_channels = 1 # BW
out_channels = round_filters(32, self.params)
@@ -73,8 +77,9 @@ class EfficientNet(nn.Module):
for args in _block_args:
args.in_channels = round_filters(args.in_channels, self.params)
args.out_channels = round_filters(args.out_channels, self.params)
- args.num_repeats = round_repeats(args.num_repeats, self.params)
- for _ in range(args.num_repeats):
+ num_repeats = round_repeats(args.num_repeats, self.params)
+ del args.num_repeats
+ for _ in range(num_repeats):
self._blocks.append(
MBConvBlock(
**args, bn_momentum=self.bn_momentum, bn_eps=self.bn_eps,
@@ -93,6 +98,7 @@ class EfficientNet(nn.Module):
)
def extract_features(self, x: Tensor) -> Tensor:
+ """Extracts the final feature map layer."""
x = self._conv_stem(x)
for i, block in enumerate(self._blocks):
stochastic_dropout_rate = self.stochastic_dropout_rate
@@ -103,4 +109,5 @@ class EfficientNet(nn.Module):
return x
def forward(self, x: Tensor) -> Tensor:
+ """Returns efficientnet image features."""
return self.extract_features(x)
diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv.py b/text_recognizer/networks/encoders/efficientnet/mbconv.py
index 3aa63d0..e85df87 100644
--- a/text_recognizer/networks/encoders/efficientnet/mbconv.py
+++ b/text_recognizer/networks/encoders/efficientnet/mbconv.py
@@ -1,76 +1,62 @@
"""Mobile inverted residual block."""
-from typing import Any, Optional, Union, Tuple
+from typing import Optional, Sequence, Union, Tuple
+import attr
import torch
from torch import nn, Tensor
import torch.nn.functional as F
-from .utils import stochastic_depth
+from text_recognizer.networks.encoders.efficientnet.utils import stochastic_depth
+def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]:
+ """Converts int to tuple."""
+ return (
+ (stride,) * 2 if isinstance(stride, int) else stride
+ )
+
+
+@attr.s(eq=False)
class MBConvBlock(nn.Module):
"""Mobile Inverted Residual Bottleneck block."""
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- kernel_size: int,
- stride: Union[Tuple[int, int], int],
- bn_momentum: float,
- bn_eps: float,
- se_ratio: float,
- expand_ratio: int,
- *args: Any,
- **kwargs: Any,
- ) -> None:
+ def __attrs_pre_init__(self) -> None:
super().__init__()
- self.kernel_size = kernel_size
- self.stride = (stride,) * 2 if isinstance(stride, int) else stride
- self.bn_momentum = bn_momentum
- self.bn_eps = bn_eps
- self.in_channels = in_channels
- self.out_channels = out_channels
+ in_channels: int = attr.ib()
+ out_channels: int = attr.ib()
+ kernel_size: Tuple[int, int] = attr.ib()
+ stride: Tuple[int, int] = attr.ib(converter=_convert_stride)
+ bn_momentum: float = attr.ib()
+ bn_eps: float = attr.ib()
+ se_ratio: float = attr.ib()
+ expand_ratio: int = attr.ib()
+ pad: Tuple[int, int, int, int] = attr.ib(init=False)
+ _inverted_bottleneck: nn.Sequential = attr.ib(init=False)
+ _depthwise: nn.Sequential = attr.ib(init=False)
+ _squeeze_excite: nn.Sequential = attr.ib(init=False)
+ _pointwise: nn.Sequential = attr.ib(init=False)
+
+ @pad.default
+ def _configure_padding(self) -> Tuple[int, int, int, int]:
+ """Set padding for convolutional layers."""
if self.stride == (2, 2):
- self.pad = [
+ return (
(self.kernel_size - 1) // 2 - 1,
(self.kernel_size - 1) // 2,
- ] * 2
- else:
- self.pad = [(self.kernel_size - 1) // 2] * 4
-
- # Placeholders for layers.
- self._inverted_bottleneck: nn.Sequential = None
- self._depthwise: nn.Sequential = None
- self._squeeze_excite: nn.Sequential = None
- self._pointwise: nn.Sequential = None
-
- self._build(
- in_channels=in_channels,
- out_channels=out_channels,
- kernel_size=kernel_size,
- stride=stride,
- expand_ratio=expand_ratio,
- se_ratio=se_ratio,
- )
+ ) * 2
+ return ((self.kernel_size - 1) // 2,) * 4
+
+ def __attrs_post_init__(self) -> None:
+ """Post init configuration."""
+ self._build()
- def _build(
- self,
- in_channels: int,
- out_channels: int,
- kernel_size: int,
- stride: Union[Tuple[int, int], int],
- expand_ratio: int,
- se_ratio: float,
- ) -> None:
- has_se = se_ratio is not None and 0.0 < se_ratio < 1.0
- inner_channels = in_channels * expand_ratio
+ def _build(self) -> None:
+ has_se = self.se_ratio is not None and 0.0 < self.se_ratio < 1.0
+ inner_channels = self.in_channels * self.expand_ratio
self._inverted_bottleneck = (
- self._configure_inverted_bottleneck(
- in_channels=in_channels, out_channels=inner_channels,
- )
- if expand_ratio != 1
+ self._configure_inverted_bottleneck(out_channels=inner_channels)
+ if self.expand_ratio != 1
else None
)
@@ -78,31 +64,23 @@ class MBConvBlock(nn.Module):
in_channels=inner_channels,
out_channels=inner_channels,
groups=inner_channels,
- kernel_size=kernel_size,
- stride=stride,
)
self._squeeze_excite = (
self._configure_squeeze_excite(
- in_channels=inner_channels,
- out_channels=inner_channels,
- se_ratio=se_ratio,
+ in_channels=inner_channels, out_channels=inner_channels,
)
if has_se
else None
)
- self._pointwise = self._configure_pointwise(
- in_channels=inner_channels, out_channels=out_channels
- )
+ self._pointwise = self._configure_pointwise(in_channels=inner_channels)
- def _configure_inverted_bottleneck(
- self, in_channels: int, out_channels: int,
- ) -> nn.Sequential:
+ def _configure_inverted_bottleneck(self, out_channels: int) -> nn.Sequential:
"""Expansion phase."""
return nn.Sequential(
nn.Conv2d(
- in_channels=in_channels,
+ in_channels=self.in_channels,
out_channels=out_channels,
kernel_size=1,
bias=False,
@@ -114,19 +92,14 @@ class MBConvBlock(nn.Module):
)
def _configure_depthwise(
- self,
- in_channels: int,
- out_channels: int,
- groups: int,
- kernel_size: int,
- stride: Union[Tuple[int, int], int],
+ self, in_channels: int, out_channels: int, groups: int,
) -> nn.Sequential:
return nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
- kernel_size=kernel_size,
- stride=stride,
+ kernel_size=self.kernel_size,
+ stride=self.stride,
groups=groups,
bias=False,
),
@@ -137,9 +110,9 @@ class MBConvBlock(nn.Module):
)
def _configure_squeeze_excite(
- self, in_channels: int, out_channels: int, se_ratio: float
+ self, in_channels: int, out_channels: int
) -> nn.Sequential:
- num_squeezed_channels = max(1, int(in_channels * se_ratio))
+ num_squeezed_channels = max(1, int(in_channels * self.se_ratio))
return nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
@@ -154,18 +127,18 @@ class MBConvBlock(nn.Module):
),
)
- def _configure_pointwise(
- self, in_channels: int, out_channels: int
- ) -> nn.Sequential:
+ def _configure_pointwise(self, in_channels: int) -> nn.Sequential:
return nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
- out_channels=out_channels,
+ out_channels=self.out_channels,
kernel_size=1,
bias=False,
),
nn.BatchNorm2d(
- num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
+ num_features=self.out_channels,
+ momentum=self.bn_momentum,
+ eps=self.bn_eps,
),
)
@@ -186,8 +159,8 @@ class MBConvBlock(nn.Module):
residual = x
if self._inverted_bottleneck is not None:
x = self._inverted_bottleneck(x)
- x = F.pad(x, self.pad)
+ x = F.pad(x, self.pad)
x = self._depthwise(x)
if self._squeeze_excite is not None:
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 9202cce..37ce29e 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -15,7 +15,7 @@ from text_recognizer.networks.transformer.positional_encodings.rotary_embedding
)
-@attr.s
+@attr.s(eq=False)
class Attention(nn.Module):
"""Standard attention."""
@@ -31,7 +31,6 @@ class Attention(nn.Module):
dropout: nn.Dropout = attr.ib(init=False)
fc: nn.Linear = attr.ib(init=False)
qkv_fn: nn.Sequential = attr.ib(init=False)
- attn_fn: F.softmax = attr.ib(init=False, default=F.softmax)
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
@@ -80,7 +79,7 @@ class Attention(nn.Module):
else k_mask
)
q_mask = rearrange(q_mask, "b i -> b () i ()")
- k_mask = rearrange(k_mask, "b i -> b () () j")
+ k_mask = rearrange(k_mask, "b j -> b () () j")
return q_mask * k_mask
return
@@ -129,7 +128,7 @@ class Attention(nn.Module):
if self.causal:
energy = self._apply_causal_mask(energy, mask, mask_value, device)
- attn = self.attn_fn(energy, dim=-1)
+ attn = F.softmax(energy, dim=-1)
attn = self.dropout(attn)
out = einsum("b h i j, b h j d -> b h i d", attn, v)
out = rearrange(out, "b h n d -> b n (h d)")
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index 66c9c50..ce443e5 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -12,7 +12,7 @@ from text_recognizer.networks.transformer.positional_encodings.rotary_embedding
from text_recognizer.networks.util import load_partial_fn
-@attr.s
+@attr.s(eq=False)
class AttentionLayers(nn.Module):
"""Standard transfomer layer."""
@@ -101,11 +101,11 @@ class AttentionLayers(nn.Module):
return x
-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, eq=False)
class Encoder(AttentionLayers):
causal: bool = attr.ib(default=False, init=False)
-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, eq=False)
class Decoder(AttentionLayers):
causal: bool = attr.ib(default=True, init=False)