From 3a21c29e2eff4378c63717f8920ca3ccbfef013c Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 3 Oct 2021 00:31:00 +0200 Subject: Lint files --- text_recognizer/criterions/label_smoothing.py | 3 ++- text_recognizer/data/base_data_module.py | 13 ++++++---- text_recognizer/data/base_dataset.py | 19 +++++++++++---- text_recognizer/data/build_transitions.py | 2 -- text_recognizer/data/emnist_lines.py | 11 ++++++--- text_recognizer/data/iam.py | 29 ++++++++++++++++------ text_recognizer/data/iam_synthetic_paragraphs.py | 31 +++++++++++++++--------- text_recognizer/data/sentence_generator.py | 7 ++++-- text_recognizer/data/transforms.py | 1 + text_recognizer/data/word_piece_mapping.py | 15 +++++++++--- text_recognizer/networks/conv_transformer.py | 17 ++++++------- 11 files changed, 96 insertions(+), 52 deletions(-) diff --git a/text_recognizer/criterions/label_smoothing.py b/text_recognizer/criterions/label_smoothing.py index 74ff145..5b3a47e 100644 --- a/text_recognizer/criterions/label_smoothing.py +++ b/text_recognizer/criterions/label_smoothing.py @@ -21,7 +21,8 @@ class LabelSmoothingLoss(nn.Module): self, ignore_index: int = -100, smoothing: float = 0.0, dim: int = -1 ) -> None: super().__init__() - assert 0.0 < smoothing <= 1.0 + if not 0.0 < smoothing < 1.0: + raise ValueError("Smoothing must be between 0.0 and 1.0") self.ignore_index = ignore_index self.confidence = 1.0 - smoothing self.smoothing = smoothing diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py index 16a06d9..ee70176 100644 --- a/text_recognizer/data/base_data_module.py +++ b/text_recognizer/data/base_data_module.py @@ -1,13 +1,15 @@ """Base lightning DataModule class.""" from pathlib import Path -from typing import Dict, Tuple, Type +from typing import Dict, Optional, Tuple, Type, TypeVar import attr from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader -from text_recognizer.data.base_mapping import 
AbstractMapping from text_recognizer.data.base_dataset import BaseDataset +from text_recognizer.data.base_mapping import AbstractMapping + +T = TypeVar("T") def load_and_print_info(data_module_class: type) -> None: @@ -23,6 +25,7 @@ class BaseDataModule(LightningDataModule): """Base PyTorch Lightning DataModule.""" def __attrs_pre_init__(self) -> None: + """Pre init constructor.""" super().__init__() mapping: Type[AbstractMapping] = attr.ib() @@ -38,7 +41,7 @@ class BaseDataModule(LightningDataModule): output_dims: Tuple[int, ...] = attr.ib(init=False, default=None) @classmethod - def data_dirname(cls) -> Path: + def data_dirname(cls: T) -> Path: """Return the path to the base data directory.""" return Path(__file__).resolve().parents[2] / "data" @@ -53,14 +56,14 @@ class BaseDataModule(LightningDataModule): """Prepare data for training.""" pass - def setup(self, stage: str = None) -> None: + def setup(self, stage: Optional[str] = None) -> None: """Split into train, val, test, and set dims. Should assign `torch Dataset` objects to self.data_train, self.data_val, and optionally self.data_test. Args: - stage (Any): Variable to set splits. + stage (Optional[str]): Variable to set splits. """ pass diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py index 8640d92..e08130d 100644 --- a/text_recognizer/data/base_dataset.py +++ b/text_recognizer/data/base_dataset.py @@ -9,8 +9,7 @@ from torch.utils.data import Dataset @attr.s class BaseDataset(Dataset): - """ - Base Dataset class that processes data and targets through optional transfroms. + r"""Base Dataset class that processes data and targets through optional transfroms. Args: data (Union[Sequence, Tensor]): Torch tensors, numpy arrays, or PIL images. 
@@ -26,9 +25,11 @@ class BaseDataset(Dataset): target_transform: Optional[Callable] = attr.ib(default=None) def __attrs_pre_init__(self) -> None: + """Pre init constructor.""" super().__init__() def __attrs_post_init__(self) -> None: + """Post init constructor.""" if len(self.data) != len(self.targets): raise ValueError("Data and targets must be of equal length.") @@ -60,9 +61,17 @@ class BaseDataset(Dataset): def convert_strings_to_labels( strings: Sequence[str], mapping: Dict[str, int], length: int ) -> Tensor: - """ - Convert a sequence of N strings to (N, length) ndarray, with each string wrapped with <s> and <e> tokens, - and padded wiht the <p> token. + r"""Convert a sequence of N strings to (N, length) ndarray. + + Add each string with <s> and <e> tokens, and padded wiht the <p> token. + + Args: + strings (Sequence[str]): Sequence of strings. + mapping (Dict[str, int]): Mapping of characters and digits to integers. + length (int): Max lenght of all strings. + + Returns: + Tensor: Target with emnist mapping indices. """ labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["<p>"] for i, string in enumerate(strings): diff --git a/text_recognizer/data/build_transitions.py b/text_recognizer/data/build_transitions.py index 91f8c1a..0f987ca 100644 --- a/text_recognizer/data/build_transitions.py +++ b/text_recognizer/data/build_transitions.py @@ -1,9 +1,7 @@ """Builds transition graph. Most code stolen from here: - https://github.com/facebookresearch/gtn_applications/blob/master/scripts/build_transitions.py - """ import collections diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py index d4b2b40..3ff8a54 100644 --- a/text_recognizer/data/emnist_lines.py +++ b/text_recognizer/data/emnist_lines.py @@ -8,6 +8,7 @@ import h5py from loguru import logger as log import numpy as np import torch +from torch import Tensor from torchvision import transforms from torchvision.transforms.functional import InterpolationMode @@ -190,7 +191,9 @@ def _get_samples_by_char( return samples_by_char -def _select_letter_samples_for_string(string: str, samples_by_char: defaultdict): +def _select_letter_samples_for_string( + string: str, samples_by_char: defaultdict +) -> List[Tensor]: null_image = torch.zeros((28, 28), dtype=torch.uint8) sample_image_by_char = {} for char in string: @@ -208,7 +211,7 @@ def _construct_image_from_string( min_overlap: float, max_overlap: float, width: int, -) -> torch.Tensor: +) -> Tensor: overlap = np.random.uniform(min_overlap, max_overlap) sampled_images = _select_letter_samples_for_string(string, samples_by_char) H, W = sampled_images[0].shape @@ -218,7 +221,7 @@ def _construct_image_from_string( for image in sampled_images: concatenated_image[:, x : (x + W)] += image x += next_overlap_width - return torch.minimum(torch.Tensor([255]), concatenated_image) + return torch.minimum(Tensor([255]), concatenated_image) def _create_dataset_of_images( @@ -228,7 +231,7 @@ def _create_dataset_of_images( min_overlap: float, max_overlap: float, dims: Tuple, -) ->
Tuple[Tensor, Tensor]: images = torch.zeros((num_samples, IMAGE_HEIGHT, dims[2])) labels = [] for n in range(num_samples): diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py index 7278eb2..263bf8e 100644 --- a/text_recognizer/data/iam.py +++ b/text_recognizer/data/iam.py @@ -1,4 +1,8 @@ -"""Class for loading the IAM dataset, which encompasses both paragraphs and lines, with associated utilities.""" +"""Class for loading the IAM dataset. + +Which encompasses both paragraphs and lines, with associated utilities. +""" + import os from pathlib import Path from typing import Any, Dict, List @@ -25,21 +29,25 @@ LINE_REGION_PADDING = 16 # Add this many pixels around the exact coordinates. @attr.s(auto_attribs=True) class IAM(BaseDataModule): - """ - "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text, - which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels. - From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database + r"""The IAM Lines dataset. + + First published at the ICDAR 1999, contains forms of unconstrained handwritten text, + which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray + levels. From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database The data split we will use is - IAM lines Large Writer Independent Text Line Recognition Task (lwitlrt): 9,862 text lines. + IAM lines Large Writer Independent Text Line Recognition Task (lwitlrt): 9,862 text + lines. The validation set has been merged into the train set. The train set has 7,101 lines from 326 writers. The test set has 1,861 lines from 128 writers. - The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only. + The text lines of all data sets are mutually exclusive, thus each writer has + contributed to one set only. 
""" metadata: Dict = attr.ib(init=False, default=toml.load(METADATA_FILENAME)) def prepare_data(self) -> None: + """Prepares the IAM dataset.""" if self.xml_filenames: return filename = download_dataset(self.metadata, DL_DATA_DIRNAME) @@ -47,18 +55,22 @@ class IAM(BaseDataModule): @property def xml_filenames(self) -> List[Path]: + """Returns the xml filenames.""" return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml")) @property def form_filenames(self) -> List[Path]: + """Returns the form filenames.""" return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg")) @property def form_filenames_by_id(self) -> Dict[str, Path]: + """Returns dictionary with filename and path.""" return {filename.stem: filename for filename in self.form_filenames} @property def split_by_id(self) -> Dict[str, str]: + """Splits files into train and test.""" return { filename.stem: "test" if filename.stem in self.metadata["test_ids"] @@ -76,7 +88,7 @@ class IAM(BaseDataModule): @cachedproperty def line_regions_by_id(self) -> Dict[str, List[Dict[str, int]]]: - """Return a dict from name IAM form to list of (x1, x2, y1, y2) coordinates of all lines in it.""" + """Return a dict from name IAM form to list of (x1, x2, y1, y2).""" return { filename.stem: _get_line_regions_from_xml_file(filename) for filename in self.xml_filenames @@ -129,4 +141,5 @@ def _get_line_region_from_xml_file(xml_line: Any) -> Dict[str, int]: def download_iam() -> None: + """Downloads and prints IAM dataset.""" load_and_print_info(IAM) diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py index b9cf90d..f253427 100644 --- a/text_recognizer/data/iam_synthetic_paragraphs.py +++ b/text_recognizer/data/iam_synthetic_paragraphs.py @@ -7,20 +7,11 @@ from loguru import logger as log import numpy as np from PIL import Image +from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.base_dataset import ( 
BaseDataset, convert_strings_to_labels, ) -from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info -from text_recognizer.data.iam_paragraphs import ( - get_dataset_properties, - get_transform, - get_target_transform, - NEW_LINE_TOKEN, - IAMParagraphs, - IMAGE_SCALE_FACTOR, - resize_image, -) from text_recognizer.data.emnist_mapping import EmnistMapping from text_recognizer.data.iam import IAM from text_recognizer.data.iam_lines import ( @@ -28,6 +19,15 @@ from text_recognizer.data.iam_lines import ( load_line_crops_and_labels, save_images_and_labels, ) +from text_recognizer.data.iam_paragraphs import ( + get_dataset_properties, + get_target_transform, + get_transform, + IAMParagraphs, + IMAGE_SCALE_FACTOR, + NEW_LINE_TOKEN, + resize_image, +) PROCESSED_DATA_DIRNAME = ( @@ -146,7 +146,10 @@ def generate_synthetic_paragraphs( ) if len(paragraph_label) > paragraphs_properties["label_length"]["max"]: log.info( - "Label longer than longest label in original IAM paragraph dataset - hence dropping." + ( + "Label longer than longest label in original IAM paragraph dataset" + " - hence dropping." + ) ) continue @@ -160,7 +163,10 @@ def generate_synthetic_paragraphs( or paragraph_crop.width > max_paragraph_shape[1] ): log.info( - "Crop larger than largest crop in original IAM paragraphs dataset - hence dropping" + ( + "Crop larger than largest crop in original IAM paragraphs dataset" + " - hence dropping" + ) ) continue @@ -213,4 +219,5 @@ def generate_random_batches( def create_synthetic_iam_paragraphs() -> None: + """Creates and prints IAM Synthetic Paragraphs dataset.""" load_and_print_info(IAMSyntheticParagraphs) diff --git a/text_recognizer/data/sentence_generator.py b/text_recognizer/data/sentence_generator.py index afcdbe9..8567e6d 100644 --- a/text_recognizer/data/sentence_generator.py +++ b/text_recognizer/data/sentence_generator.py @@ -28,7 +28,7 @@ class SentenceGenerator: r"""Generates a word or sentences from the Brown corpus. 
Sample a string from the Brown corpus of length at least one word and at most - max_length, padding to max_length with the '_' characters if sentence is + max_length, padding to max_length with the '_' characters if sentence is shorter. Args: @@ -39,8 +39,11 @@ class SentenceGenerator: str: A sentence from the Brown corpus. Raises: - ValueError: If max_length was not specified at initialization and not + ValueError: If max_length was not specified at initialization and not given as an argument. + + RuntimeError: If a valid string was not generated. + """ if max_length is None: max_length = self.max_length diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py index 51f52de..7f3e0d1 100644 --- a/text_recognizer/data/transforms.py +++ b/text_recognizer/data/transforms.py @@ -37,6 +37,7 @@ class WordPiece: self.max_len = max_len def __call__(self, x: Tensor) -> Tensor: + """Converts Emnist target tensor to Word piece target tensor.""" y = self.mapping.emnist_to_wordpiece_indices(x) if len(y) < self.max_len: pad_len = self.max_len - len(y) diff --git a/text_recognizer/data/word_piece_mapping.py b/text_recognizer/data/word_piece_mapping.py index 2f650cd..dc56942 100644 --- a/text_recognizer/data/word_piece_mapping.py +++ b/text_recognizer/data/word_piece_mapping.py @@ -1,9 +1,9 @@ """Word piece mapping.""" from pathlib import Path -from typing import List, Optional, Union, Set +from typing import List, Optional, Set, Union -import torch from loguru import logger as log +import torch from torch import Tensor from text_recognizer.data.emnist_mapping import EmnistMapping @@ -11,6 +11,8 @@ from text_recognizer.data.iam_preprocessor import Preprocessor class WordPieceMapping(EmnistMapping): + """Word piece mapping.""" + def __init__( self, data_dir: Optional[Path] = None, @@ -20,7 +22,7 @@ class WordPieceMapping(EmnistMapping): use_words: bool = False, prepend_wordsep: bool = False, special_tokens: Set[str] = {"<s>", "<e>", "<p>"}, - extra_symbols: Set[str] = {"\n",}, + extra_symbols: Set[str] = {"\n"}, ) -> None: super().__init__(extra_symbols=extra_symbols) self.data_dir = ( @@ -60,30 +62,37 @@ class WordPieceMapping(EmnistMapping): ) def __len__(self) -> int: + """Return number of word pieces.""" return len(self.wordpiece_processor.tokens) def get_token(self, index: Union[int, Tensor]) -> str: + """Returns token for index.""" if (index := int(index)) <= self.wordpiece_processor.num_tokens: return self.wordpiece_processor.tokens[index] raise KeyError(f"Index ({index}) not in mapping.") def get_index(self, token: str) -> Tensor: + """Returns index of token.""" if token in self.wordpiece_processor.tokens: return torch.LongTensor([self.wordpiece_processor.tokens_to_index[token]]) raise KeyError(f"Token ({token}) not found in inverse mapping.") def get_text(self, indices: Union[List[int], Tensor]) -> str: + """Returns text from indices.""" if isinstance(indices, Tensor): indices = indices.tolist() return self.wordpiece_processor.to_text(indices) def get_indices(self, text: str) -> Tensor: + """Returns indices of text.""" return self.wordpiece_processor.to_index(text) def emnist_to_wordpiece_indices(self, x: Tensor) -> Tensor: + """Returns word pieces indices from emnist indices.""" text = "".join([self.mapping[i] for i in x]) text = text.lower().replace(" ", "▁") return torch.LongTensor(self.wordpiece_processor.to_index(text)) def __getitem__(self, x: Union[int, Tensor]) -> str: + """Returns token for word piece index.""" return self.get_token(x) diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index 714792e..60c0ef8 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -1,6 +1,6 @@ """Vision transformer for character recognition.""" import math -from typing import Tuple +from typing import Tuple, Type from torch import nn, Tensor @@ -24,6 +24,8 @@ class
ConvTransformer(nn.Module): pad_index: Tensor, encoder: nn.Module, decoder: Decoder, + pixel_pos_embedding: Type[nn.Module], + token_pos_embedding: Type[nn.Module], ) -> None: super().__init__() self.input_dims = input_dims @@ -43,11 +45,7 @@ class ConvTransformer(nn.Module): out_channels=self.hidden_dim, kernel_size=1, ), - PositionalEncoding2D( - hidden_dim=self.hidden_dim, - max_h=self.input_dims[1], - max_w=self.input_dims[2], - ), + pixel_pos_embedding, nn.Flatten(start_dim=2), ) @@ -57,9 +55,8 @@ class ConvTransformer(nn.Module): ) # Positional encoding for decoder tokens. - self.token_pos_encoder = PositionalEncoding( - hidden_dim=self.hidden_dim, dropout_rate=self.dropout_rate - ) + self.token_pos_embedding = token_pos_embedding + # Head self.head = nn.Linear( in_features=self.hidden_dim, out_features=self.num_classes @@ -119,7 +116,7 @@ class ConvTransformer(nn.Module): trg = trg.long() trg_mask = trg != self.pad_index trg = self.token_embedding(trg) * math.sqrt(self.hidden_dim) - trg = self.token_pos_encoder(trg) + trg = self.token_pos_embedding(trg) out = self.decoder(x=trg, context=src, mask=trg_mask) logits = self.head(out) # [B, Sy, T] logits = logits.permute(0, 2, 1) # [B, T, Sy] -- cgit v1.2.3-70-g09d2