summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--text_recognizer/criterions/label_smoothing.py3
-rw-r--r--text_recognizer/data/base_data_module.py13
-rw-r--r--text_recognizer/data/base_dataset.py19
-rw-r--r--text_recognizer/data/build_transitions.py2
-rw-r--r--text_recognizer/data/emnist_lines.py11
-rw-r--r--text_recognizer/data/iam.py29
-rw-r--r--text_recognizer/data/iam_synthetic_paragraphs.py31
-rw-r--r--text_recognizer/data/sentence_generator.py7
-rw-r--r--text_recognizer/data/transforms.py1
-rw-r--r--text_recognizer/data/word_piece_mapping.py15
-rw-r--r--text_recognizer/networks/conv_transformer.py17
11 files changed, 96 insertions, 52 deletions
diff --git a/text_recognizer/criterions/label_smoothing.py b/text_recognizer/criterions/label_smoothing.py
index 74ff145..5b3a47e 100644
--- a/text_recognizer/criterions/label_smoothing.py
+++ b/text_recognizer/criterions/label_smoothing.py
@@ -21,7 +21,8 @@ class LabelSmoothingLoss(nn.Module):
self, ignore_index: int = -100, smoothing: float = 0.0, dim: int = -1
) -> None:
super().__init__()
- assert 0.0 < smoothing <= 1.0
+ if not 0.0 < smoothing < 1.0:
+ raise ValueError("Smoothing must be between 0.0 and 1.0")
self.ignore_index = ignore_index
self.confidence = 1.0 - smoothing
self.smoothing = smoothing
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index 16a06d9..ee70176 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -1,13 +1,15 @@
"""Base lightning DataModule class."""
from pathlib import Path
-from typing import Dict, Tuple, Type
+from typing import Dict, Optional, Tuple, Type, TypeVar
import attr
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
-from text_recognizer.data.base_mapping import AbstractMapping
from text_recognizer.data.base_dataset import BaseDataset
+from text_recognizer.data.base_mapping import AbstractMapping
+
+T = TypeVar("T")
def load_and_print_info(data_module_class: type) -> None:
@@ -23,6 +25,7 @@ class BaseDataModule(LightningDataModule):
"""Base PyTorch Lightning DataModule."""
def __attrs_pre_init__(self) -> None:
+ """Pre init constructor."""
super().__init__()
mapping: Type[AbstractMapping] = attr.ib()
@@ -38,7 +41,7 @@ class BaseDataModule(LightningDataModule):
output_dims: Tuple[int, ...] = attr.ib(init=False, default=None)
@classmethod
- def data_dirname(cls) -> Path:
+ def data_dirname(cls: T) -> Path:
"""Return the path to the base data directory."""
return Path(__file__).resolve().parents[2] / "data"
@@ -53,14 +56,14 @@ class BaseDataModule(LightningDataModule):
"""Prepare data for training."""
pass
- def setup(self, stage: str = None) -> None:
+ def setup(self, stage: Optional[str] = None) -> None:
"""Split into train, val, test, and set dims.
Should assign `torch Dataset` objects to self.data_train, self.data_val, and
optionally self.data_test.
Args:
- stage (Any): Variable to set splits.
+ stage (Optional[str]): Variable to set splits.
"""
pass
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py
index 8640d92..e08130d 100644
--- a/text_recognizer/data/base_dataset.py
+++ b/text_recognizer/data/base_dataset.py
@@ -9,8 +9,7 @@ from torch.utils.data import Dataset
@attr.s
class BaseDataset(Dataset):
- """
- Base Dataset class that processes data and targets through optional transfroms.
+ r"""Base Dataset class that processes data and targets through optional transfroms.
Args:
data (Union[Sequence, Tensor]): Torch tensors, numpy arrays, or PIL images.
@@ -26,9 +25,11 @@ class BaseDataset(Dataset):
target_transform: Optional[Callable] = attr.ib(default=None)
def __attrs_pre_init__(self) -> None:
+ """Pre init constructor."""
super().__init__()
def __attrs_post_init__(self) -> None:
+ """Post init constructor."""
if len(self.data) != len(self.targets):
raise ValueError("Data and targets must be of equal length.")
@@ -60,9 +61,17 @@ class BaseDataset(Dataset):
def convert_strings_to_labels(
strings: Sequence[str], mapping: Dict[str, int], length: int
) -> Tensor:
- """
- Convert a sequence of N strings to (N, length) ndarray, with each string wrapped with <s> and </s> tokens,
- and padded wiht the <p> token.
+ r"""Convert a sequence of N strings to (N, length) ndarray.
+
+ Add each string with <s> and </s> tokens, and padded wiht the <p> token.
+
+ Args:
+ strings (Sequence[str]): Sequence of strings.
+ mapping (Dict[str, int]): Mapping of characters and digits to integers.
+ length (int): Max lenght of all strings.
+
+ Returns:
+ Tensor: Target with emnist mapping indices.
"""
labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["<p>"]
for i, string in enumerate(strings):
diff --git a/text_recognizer/data/build_transitions.py b/text_recognizer/data/build_transitions.py
index 91f8c1a..0f987ca 100644
--- a/text_recognizer/data/build_transitions.py
+++ b/text_recognizer/data/build_transitions.py
@@ -1,9 +1,7 @@
"""Builds transition graph.
Most code stolen from here:
-
https://github.com/facebookresearch/gtn_applications/blob/master/scripts/build_transitions.py
-
"""
import collections
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index d4b2b40..3ff8a54 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -8,6 +8,7 @@ import h5py
from loguru import logger as log
import numpy as np
import torch
+from torch import Tensor
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
@@ -190,7 +191,9 @@ def _get_samples_by_char(
return samples_by_char
-def _select_letter_samples_for_string(string: str, samples_by_char: defaultdict):
+def _select_letter_samples_for_string(
+ string: str, samples_by_char: defaultdict
+) -> List[Tensor]:
null_image = torch.zeros((28, 28), dtype=torch.uint8)
sample_image_by_char = {}
for char in string:
@@ -208,7 +211,7 @@ def _construct_image_from_string(
min_overlap: float,
max_overlap: float,
width: int,
-) -> torch.Tensor:
+) -> Tensor:
overlap = np.random.uniform(min_overlap, max_overlap)
sampled_images = _select_letter_samples_for_string(string, samples_by_char)
H, W = sampled_images[0].shape
@@ -218,7 +221,7 @@ def _construct_image_from_string(
for image in sampled_images:
concatenated_image[:, x : (x + W)] += image
x += next_overlap_width
- return torch.minimum(torch.Tensor([255]), concatenated_image)
+ return torch.minimum(Tensor([255]), concatenated_image)
def _create_dataset_of_images(
@@ -228,7 +231,7 @@ def _create_dataset_of_images(
min_overlap: float,
max_overlap: float,
dims: Tuple,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> Tuple[Tensor, Tensor]:
images = torch.zeros((num_samples, IMAGE_HEIGHT, dims[2]))
labels = []
for n in range(num_samples):
diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py
index 7278eb2..263bf8e 100644
--- a/text_recognizer/data/iam.py
+++ b/text_recognizer/data/iam.py
@@ -1,4 +1,8 @@
-"""Class for loading the IAM dataset, which encompasses both paragraphs and lines, with associated utilities."""
+"""Class for loading the IAM dataset.
+
+Which encompasses both paragraphs and lines, with associated utilities.
+"""
+
import os
from pathlib import Path
from typing import Any, Dict, List
@@ -25,21 +29,25 @@ LINE_REGION_PADDING = 16 # Add this many pixels around the exact coordinates.
@attr.s(auto_attribs=True)
class IAM(BaseDataModule):
- """
- "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text,
- which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels.
- From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database
+ r"""The IAM Lines dataset.
+
+ First published at the ICDAR 1999, contains forms of unconstrained handwritten text,
+ which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray
+ levels. From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database
The data split we will use is
- IAM lines Large Writer Independent Text Line Recognition Task (lwitlrt): 9,862 text lines.
+ IAM lines Large Writer Independent Text Line Recognition Task (lwitlrt): 9,862 text
+ lines.
The validation set has been merged into the train set.
The train set has 7,101 lines from 326 writers.
The test set has 1,861 lines from 128 writers.
- The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only.
+ The text lines of all data sets are mutually exclusive, thus each writer has
+ contributed to one set only.
"""
metadata: Dict = attr.ib(init=False, default=toml.load(METADATA_FILENAME))
def prepare_data(self) -> None:
+ """Prepares the IAM dataset."""
if self.xml_filenames:
return
filename = download_dataset(self.metadata, DL_DATA_DIRNAME)
@@ -47,18 +55,22 @@ class IAM(BaseDataModule):
@property
def xml_filenames(self) -> List[Path]:
+ """Returns the xml filenames."""
return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml"))
@property
def form_filenames(self) -> List[Path]:
+ """Returns the form filenames."""
return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg"))
@property
def form_filenames_by_id(self) -> Dict[str, Path]:
+ """Returns dictionary with filename and path."""
return {filename.stem: filename for filename in self.form_filenames}
@property
def split_by_id(self) -> Dict[str, str]:
+ """Splits files into train and test."""
return {
filename.stem: "test"
if filename.stem in self.metadata["test_ids"]
@@ -76,7 +88,7 @@ class IAM(BaseDataModule):
@cachedproperty
def line_regions_by_id(self) -> Dict[str, List[Dict[str, int]]]:
- """Return a dict from name IAM form to list of (x1, x2, y1, y2) coordinates of all lines in it."""
+ """Return a dict from name IAM form to list of (x1, x2, y1, y2)."""
return {
filename.stem: _get_line_regions_from_xml_file(filename)
for filename in self.xml_filenames
@@ -129,4 +141,5 @@ def _get_line_region_from_xml_file(xml_line: Any) -> Dict[str, int]:
def download_iam() -> None:
+ """Downloads and prints IAM dataset."""
load_and_print_info(IAM)
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index b9cf90d..f253427 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -7,20 +7,11 @@ from loguru import logger as log
import numpy as np
from PIL import Image
+from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.base_dataset import (
BaseDataset,
convert_strings_to_labels,
)
-from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
-from text_recognizer.data.iam_paragraphs import (
- get_dataset_properties,
- get_transform,
- get_target_transform,
- NEW_LINE_TOKEN,
- IAMParagraphs,
- IMAGE_SCALE_FACTOR,
- resize_image,
-)
from text_recognizer.data.emnist_mapping import EmnistMapping
from text_recognizer.data.iam import IAM
from text_recognizer.data.iam_lines import (
@@ -28,6 +19,15 @@ from text_recognizer.data.iam_lines import (
load_line_crops_and_labels,
save_images_and_labels,
)
+from text_recognizer.data.iam_paragraphs import (
+ get_dataset_properties,
+ get_target_transform,
+ get_transform,
+ IAMParagraphs,
+ IMAGE_SCALE_FACTOR,
+ NEW_LINE_TOKEN,
+ resize_image,
+)
PROCESSED_DATA_DIRNAME = (
@@ -146,7 +146,10 @@ def generate_synthetic_paragraphs(
)
if len(paragraph_label) > paragraphs_properties["label_length"]["max"]:
log.info(
- "Label longer than longest label in original IAM paragraph dataset - hence dropping."
+ (
+ "Label longer than longest label in original IAM paragraph dataset"
+ " - hence dropping."
+ )
)
continue
@@ -160,7 +163,10 @@ def generate_synthetic_paragraphs(
or paragraph_crop.width > max_paragraph_shape[1]
):
log.info(
- "Crop larger than largest crop in original IAM paragraphs dataset - hence dropping"
+ (
+ "Crop larger than largest crop in original IAM paragraphs dataset"
+ " - hence dropping"
+ )
)
continue
@@ -213,4 +219,5 @@ def generate_random_batches(
def create_synthetic_iam_paragraphs() -> None:
+ """Creates and prints IAM Synthetic Paragraphs dataset."""
load_and_print_info(IAMSyntheticParagraphs)
diff --git a/text_recognizer/data/sentence_generator.py b/text_recognizer/data/sentence_generator.py
index afcdbe9..8567e6d 100644
--- a/text_recognizer/data/sentence_generator.py
+++ b/text_recognizer/data/sentence_generator.py
@@ -28,7 +28,7 @@ class SentenceGenerator:
r"""Generates a word or sentences from the Brown corpus.
Sample a string from the Brown corpus of length at least one word and at most
- max_length, padding to max_length with the '_' characters if sentence is
+ max_length, padding to max_length with the '_' characters if sentence is
shorter.
Args:
@@ -39,8 +39,11 @@ class SentenceGenerator:
str: A sentence from the Brown corpus.
Raises:
- ValueError: If max_length was not specified at initialization and not
+ ValueError: If max_length was not specified at initialization and not
given as an argument.
+
+ RuntimeError: If a valid string was not generated.
+
"""
if max_length is None:
max_length = self.max_length
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
index 51f52de..7f3e0d1 100644
--- a/text_recognizer/data/transforms.py
+++ b/text_recognizer/data/transforms.py
@@ -37,6 +37,7 @@ class WordPiece:
self.max_len = max_len
def __call__(self, x: Tensor) -> Tensor:
+ """Converts Emnist target tensor to Word piece target tensor."""
y = self.mapping.emnist_to_wordpiece_indices(x)
if len(y) < self.max_len:
pad_len = self.max_len - len(y)
diff --git a/text_recognizer/data/word_piece_mapping.py b/text_recognizer/data/word_piece_mapping.py
index 2f650cd..dc56942 100644
--- a/text_recognizer/data/word_piece_mapping.py
+++ b/text_recognizer/data/word_piece_mapping.py
@@ -1,9 +1,9 @@
"""Word piece mapping."""
from pathlib import Path
-from typing import List, Optional, Union, Set
+from typing import List, Optional, Set, Union
-import torch
from loguru import logger as log
+import torch
from torch import Tensor
from text_recognizer.data.emnist_mapping import EmnistMapping
@@ -11,6 +11,8 @@ from text_recognizer.data.iam_preprocessor import Preprocessor
class WordPieceMapping(EmnistMapping):
+ """Word piece mapping."""
+
def __init__(
self,
data_dir: Optional[Path] = None,
@@ -20,7 +22,7 @@ class WordPieceMapping(EmnistMapping):
use_words: bool = False,
prepend_wordsep: bool = False,
special_tokens: Set[str] = {"<s>", "<e>", "<p>"},
- extra_symbols: Set[str] = {"\n",},
+ extra_symbols: Set[str] = {"\n"},
) -> None:
super().__init__(extra_symbols=extra_symbols)
self.data_dir = (
@@ -60,30 +62,37 @@ class WordPieceMapping(EmnistMapping):
)
def __len__(self) -> int:
+ """Return number of word pieces."""
return len(self.wordpiece_processor.tokens)
def get_token(self, index: Union[int, Tensor]) -> str:
+ """Returns token for index."""
if (index := int(index)) <= self.wordpiece_processor.num_tokens:
return self.wordpiece_processor.tokens[index]
raise KeyError(f"Index ({index}) not in mapping.")
def get_index(self, token: str) -> Tensor:
+ """Returns index of token."""
if token in self.wordpiece_processor.tokens:
return torch.LongTensor([self.wordpiece_processor.tokens_to_index[token]])
raise KeyError(f"Token ({token}) not found in inverse mapping.")
def get_text(self, indices: Union[List[int], Tensor]) -> str:
+ """Returns text from indices."""
if isinstance(indices, Tensor):
indices = indices.tolist()
return self.wordpiece_processor.to_text(indices)
def get_indices(self, text: str) -> Tensor:
+ """Returns indices of text."""
return self.wordpiece_processor.to_index(text)
def emnist_to_wordpiece_indices(self, x: Tensor) -> Tensor:
+ """Returns word pieces indices from emnist indices."""
text = "".join([self.mapping[i] for i in x])
text = text.lower().replace(" ", "▁")
return torch.LongTensor(self.wordpiece_processor.to_index(text))
def __getitem__(self, x: Union[int, Tensor]) -> str:
+ """Returns token for word piece index."""
return self.get_token(x)
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index 714792e..60c0ef8 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -1,6 +1,6 @@
"""Vision transformer for character recognition."""
import math
-from typing import Tuple
+from typing import Tuple, Type
from torch import nn, Tensor
@@ -24,6 +24,8 @@ class ConvTransformer(nn.Module):
pad_index: Tensor,
encoder: nn.Module,
decoder: Decoder,
+ pixel_pos_embedding: Type[nn.Module],
+ token_pos_embedding: Type[nn.Module],
) -> None:
super().__init__()
self.input_dims = input_dims
@@ -43,11 +45,7 @@ class ConvTransformer(nn.Module):
out_channels=self.hidden_dim,
kernel_size=1,
),
- PositionalEncoding2D(
- hidden_dim=self.hidden_dim,
- max_h=self.input_dims[1],
- max_w=self.input_dims[2],
- ),
+ pixel_pos_embedding,
nn.Flatten(start_dim=2),
)
@@ -57,9 +55,8 @@ class ConvTransformer(nn.Module):
)
# Positional encoding for decoder tokens.
- self.token_pos_encoder = PositionalEncoding(
- hidden_dim=self.hidden_dim, dropout_rate=self.dropout_rate
- )
+ self.token_pos_embedding = token_pos_embedding
+
# Head
self.head = nn.Linear(
in_features=self.hidden_dim, out_features=self.num_classes
@@ -119,7 +116,7 @@ class ConvTransformer(nn.Module):
trg = trg.long()
trg_mask = trg != self.pad_index
trg = self.token_embedding(trg) * math.sqrt(self.hidden_dim)
- trg = self.token_pos_encoder(trg)
+ trg = self.token_pos_embedding(trg)
out = self.decoder(x=trg, context=src, mask=trg_mask)
logits = self.head(out) # [B, Sy, T]
logits = logits.permute(0, 2, 1) # [B, T, Sy]