Diffstat (limited to 'text_recognizer/data')
| -rw-r--r-- | text_recognizer/data/__init__.py | 3 |
| -rw-r--r-- | text_recognizer/data/base_data_module.py | 2 |
| -rw-r--r-- | text_recognizer/data/base_dataset.py | 13 |
| -rw-r--r-- | text_recognizer/data/emnist.py | 37 |
| -rw-r--r-- | text_recognizer/data/emnist_lines.py | 32 |
| -rw-r--r-- | text_recognizer/data/iam.py | 39 |
| -rw-r--r-- | text_recognizer/data/iam_dataset.py | 133 |
| -rw-r--r-- | text_recognizer/data/iam_lines.py | 255 |
| -rw-r--r-- | text_recognizer/data/iam_lines_dataset.py | 110 |
| -rw-r--r-- | text_recognizer/data/iam_preprocessor.py | 1 |
| -rw-r--r-- | text_recognizer/data/image_utils.py | 49 |
| -rw-r--r-- | text_recognizer/data/sentence_generator.py | 7 |
| -rw-r--r-- | text_recognizer/data/transforms.py | 160 |
13 files changed, 391 insertions, 450 deletions
diff --git a/text_recognizer/data/__init__.py b/text_recognizer/data/__init__.py
index 2727b20..9a42fa9 100644
--- a/text_recognizer/data/__init__.py
+++ b/text_recognizer/data/__init__.py
@@ -1 +1,4 @@
 """Dataset modules."""
+from .base_dataset import BaseDataset, convert_strings_to_labels, split_dataset
+from .base_data_module import BaseDataModule, load_and_print_info
+from .download_utils import download_dataset
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index f5e7300..8b5c188 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -7,7 +7,7 @@ from torch.utils.data import DataLoader
 
 
 def load_and_print_info(data_module_class: type) -> None:
-    """Load EMNISTLines and prints info."""
+    """Load dataset and print dataset information."""
     dataset = data_module_class()
     dataset.prepare_data()
     dataset.setup()
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py
index a9e9c24..d00daaf 100644
--- a/text_recognizer/data/base_dataset.py
+++ b/text_recognizer/data/base_dataset.py
@@ -71,3 +71,16 @@ def convert_strings_to_labels(
         for j, token in enumerate(tokens):
             labels[i, j] = mapping[token]
     return labels
+
+
+def split_dataset(
+    dataset: BaseDataset, fraction: float, seed: int
+) -> Tuple[BaseDataset, BaseDataset]:
+    """Split dataset into two parts with fraction * size and (1 - fraction) * size."""
+    if fraction >= 1.0:
+        raise ValueError("Fraction cannot be greater than or equal to 1.0.")
+    split_1 = int(fraction * len(dataset))
+    split_2 = len(dataset) - split_1
+    return torch.utils.data.random_split(
+        dataset, [split_1, split_2], generator=torch.Generator().manual_seed(seed)
+    )
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index 7f67893..3e10b5f 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -1,6 +1,6 @@
 """EMNIST dataset: downloads it from FSDL aws url if not present."""
 from pathlib import Path
-from typing import Sequence, Tuple
+from typing import Dict, List, Sequence, Tuple
 import json
 import os
 import shutil
@@ -10,11 +10,9 @@ import h5py
 import numpy as np
 from loguru import logger
 import toml
-import torch
-from torch.utils.data import random_split
 from torchvision import transforms
 
-from text_recognizer.data.base_dataset import BaseDataset
+from text_recognizer.data.base_dataset import BaseDataset, split_dataset
 from text_recognizer.data.base_data_module import (
     BaseDataModule,
     load_and_print_info,
@@ -48,23 +46,18 @@ class EMNIST(BaseDataModule):
         self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8
     ) -> None:
         super().__init__(batch_size, num_workers)
-        if not ESSENTIALS_FILENAME.exists():
-            _download_and_process_emnist()
-        with ESSENTIALS_FILENAME.open() as f:
-            essentials = json.load(f)
         self.train_fraction = train_fraction
-        self.mapping = list(essentials["characters"])
-        self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)}
+        self.mapping, self.inverse_mapping, self.input_shape = emnist_mapping()
         self.data_train = None
         self.data_val = None
         self.data_test = None
         self.transform = transforms.Compose([transforms.ToTensor()])
-        self.dims = (1, *essentials["input_shape"])
+        self.dims = (1, *self.input_shape)
         self.output_dims = (1,)
 
     def prepare_data(self) -> None:
         if not PROCESSED_DATA_FILENAME.exists():
-            _download_and_process_emnist()
+            download_and_process_emnist()
 
     def setup(self, stage: str = None) -> None:
         if stage == "fit" or stage is None:
@@ -75,10 +68,8 @@ class EMNIST(BaseDataModule):
             dataset_train = BaseDataset(
                 self.x_train, self.y_train, transform=self.transform
             )
-            train_size = int(self.train_fraction * len(dataset_train))
-            val_size = len(dataset_train) - train_size
-            self.data_train, self.data_val = random_split(
-                dataset_train, [train_size, val_size], generator=torch.Generator()
+            self.data_train, self.data_val = split_dataset(
+                dataset_train, fraction=self.train_fraction, seed=SEED
             )
 
         if stage == "test" or stage is None:
@@ -104,7 +95,19 @@ class EMNIST(BaseDataModule):
         return basic + data
 
 
-def _download_and_process_emnist() -> None:
+def emnist_mapping() -> Tuple[List, Dict[str, int], List[int]]:
+    """Return the EMNIST mapping."""
+    if not ESSENTIALS_FILENAME.exists():
+        download_and_process_emnist()
+    with ESSENTIALS_FILENAME.open() as f:
+        essentials = json.load(f)
+    mapping = list(essentials["characters"])
+    inverse_mapping = {v: k for k, v in enumerate(mapping)}
+    input_shape = essentials["input_shape"]
+    return mapping, inverse_mapping, input_shape
+
+
+def download_and_process_emnist() -> None:
     metadata = toml.load(METADATA_FILENAME)
     download_dataset(metadata, DL_DATA_DIRNAME)
     _process_raw_dataset(metadata["filename"], DL_DATA_DIRNAME)
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 6c14add..72665d0 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -1,12 +1,11 @@
 """Dataset of generated text from EMNIST characters."""
 from collections import defaultdict
 from pathlib import Path
-from typing import Callable, Dict, Tuple, Sequence
+from typing import Callable, Dict, Tuple
 
 import h5py
 from loguru import logger
 import numpy as np
-from PIL import Image
 import torch
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
@@ -58,6 +57,7 @@ class EMNISTLines(BaseDataModule):
         self.num_test = num_test
 
         self.emnist = EMNIST()
+        # TODO: fix mapping
         self.mapping = self.emnist.mapping
         max_width = (
             int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap))
@@ -66,32 +66,28 @@ class EMNISTLines(BaseDataModule):
 
         if max_width >= IMAGE_WIDTH:
             raise ValueError(
-                    f"max_width {max_width} greater than IMAGE_WIDTH {IMAGE_WIDTH}"
-                    )
+                f"max_width {max_width} greater than IMAGE_WIDTH {IMAGE_WIDTH}"
+            )
 
-        self.dims = (
-            self.emnist.dims[0],
-            IMAGE_HEIGHT,
-            IMAGE_WIDTH
-        )
+        self.dims = (self.emnist.dims[0], IMAGE_HEIGHT, IMAGE_WIDTH)
 
         if self.max_length >= MAX_OUTPUT_LENGTH:
             raise ValueError("max_length greater than MAX_OUTPUT_LENGTH")
 
         self.output_dims = (MAX_OUTPUT_LENGTH, 1)
-        self.data_train = None
-        self.data_val = None
-        self.data_test = None
+        self.data_train: BaseDataset = None
+        self.data_val: BaseDataset = None
+        self.data_test: BaseDataset = None
 
     @property
     def data_filename(self) -> Path:
         """Return name of dataset."""
-        return (
-            DATA_DIRNAME / (f"ml_{self.max_length}_"
+        return DATA_DIRNAME / (
+            f"ml_{self.max_length}_"
             f"o{self.min_overlap:f}_{self.max_overlap:f}_"
             f"ntr{self.num_train}_"
             f"ntv{self.num_val}_"
-            f"nte{self.num_test}.h5")
+            f"nte{self.num_test}.h5"
         )
 
     def prepare_data(self) -> None:
@@ -144,7 +140,10 @@ class EMNISTLines(BaseDataModule):
 
         x, y = next(iter(self.train_dataloader()))
         data = (
-            f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n"
+            "Train/val/test sizes: "
+            f"{len(self.data_train)}, "
+            f"{len(self.data_val)}, "
+            f"{len(self.data_test)}\n"
             f"Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
             f"Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n"
         )
@@ -223,7 +222,6 @@ def _construct_image_from_string(
 ) -> torch.Tensor:
     overlap = np.random.uniform(min_overlap, max_overlap)
     sampled_images = _select_letter_samples_for_string(string, samples_by_char)
-    N = len(sampled_images)
     H, W = sampled_images[0].shape
     next_overlap_width = W - int(overlap * W)
     concatenated_image = torch.zeros((H, width), dtype=torch.uint8)
diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py
index fcfe9a7..01272ba 100644
--- a/text_recognizer/data/iam.py
+++ b/text_recognizer/data/iam.py
@@ -60,23 +60,36 @@ class IAM(BaseDataModule):
 
     @property
     def split_by_id(self) -> Dict[str, str]:
-        return {filename.stem: "test" if filename.stem in self.metadata["test_ids"] else "trainval" for filename in self.form_filenames}
+        return {
+            filename.stem: "test"
+            if filename.stem in self.metadata["test_ids"]
+            else "train"
+            for filename in self.form_filenames
+        }
 
     @cachedproperty
     def line_strings_by_id(self) -> Dict[str, List[str]]:
         """Return a dict from name of IAM form to list of line texts in it."""
-        return {filename.stem: _get_line_strings_from_xml_file(filename) for filename in self.xml_filenames}
+        return {
+            filename.stem: _get_line_strings_from_xml_file(filename)
+            for filename in self.xml_filenames
+        }
 
     @cachedproperty
     def line_regions_by_id(self) -> Dict[str, List[Dict[str, int]]]:
         """Return a dict from name of IAM form to list of (x1, x2, y1, y2) coordinates of all lines in it."""
-        return {filename.stem: _get_line_regions_from_xml_file(filename) for filename in self.xml_filenames}
+        return {
+            filename.stem: _get_line_regions_from_xml_file(filename)
+            for filename in self.xml_filenames
+        }
 
     def __repr__(self) -> str:
         """Return info about the dataset."""
-        return ("IAM Dataset\n"
-                f"Num forms total: {len(self.xml_filenames)}\n"
-                f"Num in test set: {len(self.metadata['test_ids'])}\n")
+        return (
+            "IAM Dataset\n"
+            f"Num forms total: {len(self.xml_filenames)}\n"
+            f"Num in test set: {len(self.metadata['test_ids'])}\n"
+        )
 
 
 def _extract_raw_dataset(filename: Path, dirname: Path) -> None:
@@ -92,7 +105,7 @@ def _get_line_strings_from_xml_file(filename: str) -> List[str]:
     """Get the text content of each line. Note that we replace &quot; with "."""
     xml_root_element = ElementTree.parse(filename).getroot()  # nosec
     xml_line_elements = xml_root_element.findall("handwritten-part/line")
-    return [el.attrib["text"].replace("&quot;", '"') for el in xml_line_elements]
+    return [el.attrib["text"].replace("&quot;", '"') for el in xml_line_elements]
 
 
 def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]:
@@ -107,13 +120,13 @@ def _get_line_region_from_xml_file(xml_line: Any) -> Dict[str, int]:
     x1s = [int(el.attrib["x"]) for el in word_elements]
     y1s = [int(el.attrib["y"]) for el in word_elements]
     x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements]
-    y2s = [int(el.attrib["x"]) + int(el.attrib["height"]) for el in word_elements]
+    y2s = [int(el.attrib["y"]) + int(el.attrib["height"]) for el in word_elements]
     return {
-            "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
-            "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
-            "x2": min(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
-            "y2": min(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
-            }
+        "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
+        "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
+        "x2": max(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
+        "y2": max(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
+    }
 
 
 def download_iam() -> None:
diff --git a/text_recognizer/data/iam_dataset.py b/text_recognizer/data/iam_dataset.py
deleted file mode 100644
index a8998b9..0000000
--- a/text_recognizer/data/iam_dataset.py
+++ /dev/null
@@ -1,133 +0,0 @@
-"""Class for loading the IAM dataset, which encompasses both paragraphs and lines, with associated utilities."""
-import os
-from typing import Any, Dict, List
-import zipfile
-
-from boltons.cacheutils import cachedproperty
-import defusedxml.ElementTree as ET
-from loguru import logger
-import toml
-
-from text_recognizer.datasets.util import _download_raw_dataset, DATA_DIRNAME
-
-RAW_DATA_DIRNAME = DATA_DIRNAME / "raw" / "iam"
-METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
-EXTRACTED_DATASET_DIRNAME = RAW_DATA_DIRNAME / "iamdb"
-RAW_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
-
-DOWNSAMPLE_FACTOR = 2  # If images were downsampled, the regions must also be.
-LINE_REGION_PADDING = 0  # Add this many pixels around the exact coordinates.
-
-
-class IamDataset:
-    """IAM dataset.
-
-    "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text,
-    which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels."
-    From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database
-
-    The data split we will use is
-    IAM lines Large Writer Independent Text Line Recognition Task (lwitlrt): 9,862 text lines.
-        The validation set has been merged into the train set.
-        The train set has 7,101 lines from 326 writers.
-        The test set has 1,861 lines from 128 writers.
-        The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only.
-
-    """
-
-    def __init__(self) -> None:
-        self.metadata = toml.load(METADATA_FILENAME)
-
-    def load_or_generate_data(self) -> None:
-        """Downloads IAM dataset if xml files does not exist."""
-        if not self.xml_filenames:
-            self._download_iam()
-
-    @property
-    def xml_filenames(self) -> List:
-        """List of xml filenames."""
-        return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml"))
-
-    @property
-    def form_filenames(self) -> List:
-        """List of forms filenames."""
-        return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg"))
-
-    def _download_iam(self) -> None:
-        curdir = os.getcwd()
-        os.chdir(RAW_DATA_DIRNAME)
-        _download_raw_dataset(self.metadata)
-        _extract_raw_dataset(self.metadata)
-        os.chdir(curdir)
-
-    @property
-    def form_filenames_by_id(self) -> Dict:
-        """Creates a dictionary with filenames as keys and forms as values."""
-        return {filename.stem: filename for filename in self.form_filenames}
-
-    @cachedproperty
-    def line_strings_by_id(self) -> Dict:
-        """Return a dict from name of IAM form to a list of line texts in it."""
-        return {
-            filename.stem: _get_line_strings_from_xml_file(filename)
-            for filename in self.xml_filenames
-        }
-
-    @cachedproperty
-    def line_regions_by_id(self) -> Dict:
-        """Return a dict from name of IAM form to a list of (x1, x2, y1, y2) coordinates of all lines in it."""
-        return {
-            filename.stem: _get_line_regions_from_xml_file(filename)
-            for filename in self.xml_filenames
-        }
-
-    def __repr__(self) -> str:
-        """Print info about dataset."""
-        return "IAM Dataset\n" f"Number of forms: {len(self.xml_filenames)}\n"
-
-
-def _extract_raw_dataset(metadata: Dict) -> None:
-    logger.info("Extracting IAM data.")
-    with zipfile.ZipFile(metadata["filename"], "r") as zip_file:
-        zip_file.extractall()
-
-
-def _get_line_strings_from_xml_file(filename: str) -> List[str]:
-    """Get the text content of each line. Note that we replace &quot; with "."""
-    xml_root_element = ET.parse(filename).getroot()  # nosec
-    xml_line_elements = xml_root_element.findall("handwritten-part/line")
-    return [el.attrib["text"].replace("&quot;", '"') for el in xml_line_elements]
-
-
-def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]:
-    """Get the line region dict for each line."""
-    xml_root_element = ET.parse(filename).getroot()  # nosec
-    xml_line_elements = xml_root_element.findall("handwritten-part/line")
-    return [_get_line_region_from_xml_element(el) for el in xml_line_elements]
-
-
-def _get_line_region_from_xml_element(xml_line: Any) -> Dict[str, int]:
-    """Extracts coordinates for each line of text."""
-    # TODO: fix input!
-    word_elements = xml_line.findall("word/cmp")
-    x1s = [int(el.attrib["x"]) for el in word_elements]
-    y1s = [int(el.attrib["y"]) for el in word_elements]
-    x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements]
-    y2s = [int(el.attrib["y"]) + int(el.attrib["height"]) for el in word_elements]
-    return {
-        "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
-        "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
-        "x2": max(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
-        "y2": max(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
-    }
-
-
-def main() -> None:
-    """Initializes the dataset and print info about the dataset."""
-    dataset = IamDataset()
-    dataset.load_or_generate_data()
-    print(dataset)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
new file mode 100644
index 0000000..391075a
--- /dev/null
+++ b/text_recognizer/data/iam_lines.py
@@ -0,0 +1,255 @@
+"""Class for IAM Lines dataset.
+
+If not created, will generate a handwritten lines dataset from the IAM paragraphs
+dataset.
+
+"""
+import json
+from pathlib import Path
+import random
+from typing import List, Sequence, Tuple
+
+from loguru import logger
+from PIL import Image, ImageFile, ImageOps
+import numpy as np
+from torch import Tensor
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+
+from text_recognizer.data.base_dataset import (
+    BaseDataset,
+    convert_strings_to_labels,
+    split_dataset,
+)
+from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
+from text_recognizer.data.emnist import emnist_mapping
+from text_recognizer.data.iam import IAM
+from text_recognizer.data import image_utils
+
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+SEED = 4711
+PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_lines"
+IMAGE_HEIGHT = 56
+IMAGE_WIDTH = 1024
+
+
+class IAMLines(BaseDataModule):
+    """IAM handwritten lines dataset."""
+
+    def __init__(
+        self,
+        augment: bool = True,
+        fraction: float = 0.8,
+        batch_size: int = 128,
+        num_workers: int = 0,
+    ) -> None:
+        # TODO: add transforms
+        super().__init__(batch_size, num_workers)
+        self.augment = augment
+        self.fraction = fraction
+        self.mapping, self.inverse_mapping, _ = emnist_mapping()
+        self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
+        self.output_dims = (89, 1)
+        self.data_train: BaseDataset = None
+        self.data_val: BaseDataset = None
+        self.data_test: BaseDataset = None
+
+    def prepare_data(self) -> None:
+        """Creates the IAM lines dataset if not existing."""
+        if PROCESSED_DATA_DIRNAME.exists():
+            return
+
+        logger.info("Cropping IAM lines regions...")
+        iam = IAM()
+        iam.prepare_data()
+        crops_train, labels_train = line_crops_and_labels(iam, "train")
+        crops_test, labels_test = line_crops_and_labels(iam, "test")
+
+        shapes = np.array([crop.size for crop in crops_train + crops_test])
+        aspect_ratios = shapes[:, 0] / shapes[:, 1]
+
+        logger.info("Saving images, labels, and statistics...")
+        save_images_and_labels(
+            crops_train, labels_train, "train", PROCESSED_DATA_DIRNAME
+        )
+        save_images_and_labels(crops_test, labels_test, "test", PROCESSED_DATA_DIRNAME)
+
+        with (PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt").open(mode="w") as f:
+            f.write(str(aspect_ratios.max()))
+
+    def setup(self, stage: str = None) -> None:
+        """Load data for training/testing."""
+        with (PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt").open(mode="r") as f:
+            max_aspect_ratio = float(f.read())
+            image_width = int(IMAGE_HEIGHT * max_aspect_ratio)
+            if image_width >= IMAGE_WIDTH:
+                raise ValueError("image_width equal or greater than IMAGE_WIDTH")
+
+        if stage == "fit" or stage is None:
+            x_train, labels_train = load_line_crops_and_labels(
+                "train", PROCESSED_DATA_DIRNAME
+            )
+            if self.output_dims[0] < max([len(l) for l in labels_train]) + 2:
+                raise ValueError("Target length longer than max output length.")
+
+            y_train = convert_strings_to_labels(
+                labels_train, self.inverse_mapping, length=self.output_dims[0]
+            )
+            data_train = BaseDataset(
+                x_train, y_train, transform=get_transform(IMAGE_WIDTH, self.augment)
+            )
+
+            self.data_train, self.data_val = split_dataset(
+                dataset=data_train, fraction=self.fraction, seed=SEED
+            )
+
+        if stage == "test" or stage is None:
+            x_test, labels_test = load_line_crops_and_labels(
+                "test", PROCESSED_DATA_DIRNAME
+            )
+
+            if self.output_dims[0] < max([len(l) for l in labels_test]) + 2:
+                raise ValueError("Target length longer than max output length.")
+
+            y_test = convert_strings_to_labels(
+                labels_test, self.inverse_mapping, length=self.output_dims[0]
+            )
+            self.data_test = BaseDataset(
+                x_test, y_test, transform=get_transform(IMAGE_WIDTH)
+            )
+
+        if stage is None:
+            self._verify_output_dims(labels_train, labels_test)
+
+    def _verify_output_dims(self, labels_train: Tensor, labels_test: Tensor) -> None:
+        max_label_length = max([len(label) for label in labels_train + labels_test]) + 2
+        output_dims = (max_label_length, 1)
+        if output_dims != self.output_dims:
+            raise ValueError("Output dim does not match expected output dims.")
+
+    def __repr__(self) -> str:
+        """Return information about the dataset."""
+        basic = (
+            "IAM Lines dataset\n"
+            f"Num classes: {len(self.mapping)}\n"
+            f"Input dims: {self.dims}\n"
+            f"Output dims: {self.output_dims}\n"
+        )
+
+        if not any([self.data_train, self.data_val, self.data_test]):
+            return basic
+
+        x, y = next(iter(self.train_dataloader()))
+        xt, yt = next(iter(self.test_dataloader()))
+        data = (
+            "Train/val/test sizes: "
+            f"{len(self.data_train)}, "
+            f"{len(self.data_val)}, "
+            f"{len(self.data_test)}\n"
+            f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
+            f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n"
+            f"Test Batch x stats: {(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n"
+            f"Test Batch y stats: {(yt.shape, yt.dtype, yt.min(), yt.max())}\n"
+        )
+        return basic + data
+
+
+def line_crops_and_labels(iam: IAM, split: str) -> Tuple[List, List]:
+    """Load IAM line labels and regions, and load image crops."""
+    crops = []
+    labels = []
+    for filename in iam.form_filenames:
+        if not iam.split_by_id[filename.stem] == split:
+            continue
+        image = image_utils.read_image_pil(filename)
+        image = ImageOps.grayscale(image)
+        image = ImageOps.invert(image)
+        labels += iam.line_strings_by_id[filename.stem]
+        crops += [
+            image.crop([region[box] for box in ["x1", "y1", "x2", "y2"]])
+            for region in iam.line_regions_by_id[filename.stem]
+        ]
+    if len(crops) != len(labels):
+        raise ValueError("Length of crops does not match length of labels")
+    return crops, labels
+
+
+def save_images_and_labels(
+    crops: Sequence[Image.Image], labels: Sequence[str], split: str, data_dirname: Path
+) -> None:
+    (data_dirname / split).mkdir(parents=True, exist_ok=True)
+
+    with (data_dirname / split / "_labels.json").open(mode="w") as f:
+        json.dump(labels, f)
+
+    for index, crop in enumerate(crops):
+        crop.save(data_dirname / split / f"{index}.png")
+
+
+def load_line_crops_and_labels(split: str, data_dirname: Path) -> Tuple[List, List]:
+    """Load line crops and labels for given split from processed directory."""
+    with (data_dirname / split / "_labels.json").open(mode="r") as f:
+        labels = json.load(f)
+
+    crop_filenames = sorted(
+        (data_dirname / split).glob("*.png"),
+        key=lambda filename: int(Path(filename).stem),
+    )
+    crops = [
+        image_utils.read_image_pil(filename, grayscale=True)
+        for filename in crop_filenames
+    ]
+
+    if len(crops) != len(labels):
+        raise ValueError("Length of crops does not match length of labels")
+
+    return crops, labels
+
+
+def get_transform(image_width: int, augment: bool = False) -> transforms.Compose:
+    """Augment with brightness, slight rotation, slant, translation, scale, and Gaussian noise."""
+
+    def embed_crop(crop: Image, augment: bool = augment, image_width: int = image_width) -> Image:
+        # Crop is PIL.Image of dtype="L" (so value range is [0, 255])
+        image = Image.new("L", (image_width, IMAGE_HEIGHT))
+
+        # Resize crop.
+        crop_width, crop_height = crop.size
+        new_crop_height = IMAGE_HEIGHT
+        new_crop_width = int(new_crop_height * (crop_width / crop_height))
+
+        if augment:
+            # Add random stretching
+            new_crop_width = int(new_crop_width * random.uniform(0.9, 1.1))
+            new_crop_width = min(new_crop_width, image_width)
+        crop_resized = crop.resize(
+            (new_crop_width, new_crop_height), resample=Image.BILINEAR
+        )
+
+        # Embed in image
+        x = min(28, image_width - new_crop_width)
+        y = IMAGE_HEIGHT - new_crop_height
+        image.paste(crop_resized, (x, y))
+
+        return image
+
+    transforms_list = [transforms.Lambda(embed_crop)]
+
+    if augment:
+        transforms_list += [
+            transforms.ColorJitter(brightness=(0.8, 1.6)),
+            transforms.RandomAffine(
+                degrees=1,
+                shear=(-30, 20),
+                interpolation=InterpolationMode.BILINEAR,
+                fill=0,
+            ),
+        ]
+    transforms_list.append(transforms.ToTensor())
+    return transforms.Compose(transforms_list)
+
+
+def generate_iam_lines() -> None:
+    load_and_print_info(IAMLines)
diff --git a/text_recognizer/data/iam_lines_dataset.py b/text_recognizer/data/iam_lines_dataset.py
deleted file mode 100644
index 1cb84bd..0000000
--- a/text_recognizer/data/iam_lines_dataset.py
+++ /dev/null
@@ -1,110 +0,0 @@
-"""IamLinesDataset class."""
-from typing import Callable, Dict, List, Optional, Tuple, Union
-
-import h5py
-from loguru import logger
-import torch
-from torch import Tensor
-from torchvision.transforms import ToTensor
-
-from text_recognizer.datasets.dataset import Dataset
-from text_recognizer.datasets.util import (
-    compute_sha256,
-    DATA_DIRNAME,
-    download_url,
-    EmnistMapper,
-)
-
-
-PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_lines"
-PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "iam_lines.h5"
-PROCESSED_DATA_URL = (
-    "https://s3-us-west-2.amazonaws.com/fsdl-public-assets/iam_lines.h5"
-)
-
-
-class IamLinesDataset(Dataset):
-    """IAM lines datasets for handwritten text lines."""
-
-    def __init__(
-        self,
-        train: bool = False,
-        subsample_fraction: float = None,
-        transform: Optional[Callable] = None,
-        target_transform: Optional[Callable] = None,
-        init_token: Optional[str] = None,
-        pad_token: Optional[str] = None,
-        eos_token: Optional[str] = None,
-        lower: bool = False,
-    ) -> None:
-        self.pad_token = "_" if pad_token is None else pad_token
-
-        super().__init__(
-            train=train,
-            subsample_fraction=subsample_fraction,
-            transform=transform,
-            target_transform=target_transform,
-            init_token=init_token,
-            pad_token=pad_token,
-            eos_token=eos_token,
-            lower=lower,
-        )
-
-    @property
-    def input_shape(self) -> Tuple:
-        """Input shape of the data."""
-        return self.data.shape[1:] if self.data is not None else None
-
-    @property
-    def output_shape(self) -> Tuple:
-        """Output shape of the data."""
-        return (
-            self.targets.shape[1:] + (self.num_classes,)
-            if self.targets is not None
-            else None
-        )
-
-    def load_or_generate_data(self) -> None:
-        """Load or generate dataset data."""
-        if not PROCESSED_DATA_FILENAME.exists():
-            PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
-            logger.info("Downloading IAM lines...")
-            download_url(PROCESSED_DATA_URL, PROCESSED_DATA_FILENAME)
-        with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
-            self._data = f[f"x_{self.split}"][:]
-            self._targets = f[f"y_{self.split}"][:]
-        self._subsample()
-
-    def __repr__(self) -> str:
-        """Print info about the dataset."""
-        return (
-            "IAM Lines Dataset\n"  # pylint: disable=no-member
-            f"Number classes: {self.num_classes}\n"
-            f"Mapping: {self.mapper.mapping}\n"
-            f"Data: {self.data.shape}\n"
-            f"Targets: {self.targets.shape}\n"
-        )
-
-    def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]:
-        """Fetches data, target pair of the dataset for a given and index or indices.
-
-        Args:
-            index (Union[int, Tensor]): Either a list or int of indices/index.
-
-        Returns:
-            Tuple[Tensor, Tensor]: Data target pair.
-
-        """
-        if torch.is_tensor(index):
-            index = index.tolist()
-
-        data = self.data[index]
-        targets = self.targets[index]
-
-        if self.transform:
-            data = self.transform(data)
-
-        if self.target_transform:
-            targets = self.target_transform(targets)
-
-        return data, targets
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py
index a93eb00..5d0fad6 100644
--- a/text_recognizer/data/iam_preprocessor.py
+++ b/text_recognizer/data/iam_preprocessor.py
@@ -5,7 +5,6 @@ The code is mostly stolen from:
     https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py
 
 """
-
 import collections
 import itertools
 from pathlib import Path
diff --git a/text_recognizer/data/image_utils.py b/text_recognizer/data/image_utils.py
new file mode 100644
index 0000000..c2b8915
--- /dev/null
+++ b/text_recognizer/data/image_utils.py
@@ -0,0 +1,49 @@
+"""Image util functions for loading and saving images."""
+from pathlib import Path
+from typing import Union
+from urllib.request import urlopen
+
+import cv2
+import numpy as np
+from PIL import Image
+
+
+def read_image_pil(image_uri: Union[Path, str], grayscale: bool = False) -> Image:
+    """Return PIL image."""
+    image = Image.open(image_uri)
+    if grayscale:
+        image = image.convert("L")
+    return image
+
+
+def read_image(image_uri: Union[Path, str], grayscale: bool = False) -> np.array:
+    """Read image_uri."""
+
+    if isinstance(image_uri, str):
+        image_uri = Path(image_uri)
+
+    def read_image_from_filename(image_filename: Path, imread_flag: int) -> np.array:
+        return cv2.imread(str(image_filename), imread_flag)
+
+    def read_image_from_url(image_url: Path, imread_flag: int) -> np.array:
+        url_response = urlopen(str(image_url))  # nosec
+        image_array = np.array(bytearray(url_response.read()), dtype=np.uint8)
+        return cv2.imdecode(image_array, imread_flag)
+
+    imread_flag = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
+    image = None
+
+    if image_uri.exists():
+        image = read_image_from_filename(image_uri, imread_flag)
+    else:
+        image = read_image_from_url(image_uri, imread_flag)
+
+    if image is None:
+        raise ValueError(f"Could not load image at {image_uri}")
+
+    return image
+
+
+def write_image(image: np.ndarray, filename: Union[Path, str]) -> None:
+    """Write image to file."""
+    cv2.imwrite(str(filename), image)
diff --git a/text_recognizer/data/sentence_generator.py b/text_recognizer/data/sentence_generator.py
index 53b781c..f09703b 100644
--- a/text_recognizer/data/sentence_generator.py
+++ b/text_recognizer/data/sentence_generator.py
@@ -1,5 +1,4 @@
 """Downloading the Brown corpus with NLTK for sentence generating."""
-
 import itertools
 import re
 import string
@@ -9,9 +8,9 @@ import nltk
 from nltk.corpus.reader.util import ConcatenatedCorpusView
 import numpy as np
 
-from text_recognizer.datasets.util import DATA_DIRNAME
+from text_recognizer.data.base_data_module import BaseDataModule
 
-NLTK_DATA_DIRNAME = DATA_DIRNAME / "downloaded" / "nltk"
+NLTK_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "nltk"
 
 
 class SentenceGenerator:
@@ -47,7 +46,7 @@ class SentenceGenerator:
             raise ValueError(
                 "Must provide max_length to this method or when making this object."
             )
-
+
         for _ in range(10):
             try:
                 index = np.random.randint(0, len(self.word_start_indices) - 1)
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
index b6a48f5..2291eec 100644
--- a/text_recognizer/data/transforms.py
+++ b/text_recognizer/data/transforms.py
@@ -1,148 +1,15 @@
 """Transforms for PyTorch datasets."""
 from abc import abstractmethod
 from pathlib import Path
-import random
 from typing import Any, Optional, Union
 
 from loguru import logger
-import numpy as np
-from PIL import Image
 import torch
 from torch import Tensor
-import torch.nn.functional as F
-from torchvision import transforms
-from torchvision.transforms import (
-    ColorJitter,
-    Compose,
-    Normalize,
-    RandomAffine,
-    RandomHorizontalFlip,
-    RandomRotation,
-    ToPILImage,
-    ToTensor,
-)
 
 from text_recognizer.datasets.iam_preprocessor import Preprocessor
-from text_recognizer.datasets.util import EmnistMapper
-
-
-class RandomResizeCrop:
-    """Image transform with random resize and crop applied.
-
-    Stolen from
-
-    https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py
-
-    """
-
-    def __init__(self, jitter: int = 10, ratio: float = 0.5) -> None:
-        self.jitter = jitter
-        self.ratio = ratio
-
-    def __call__(self, img: np.ndarray) -> np.ndarray:
-        """Applies random crop and rotation to an image."""
-        w, h = img.size
-
-        # pad with white:
-        img = transforms.functional.pad(img, self.jitter, fill=255)
-
-        # crop at random (x, y):
-        x = self.jitter + random.randint(-self.jitter, self.jitter)
-        y = self.jitter + random.randint(-self.jitter, self.jitter)
-
-        # randomize aspect ratio:
-        size_w = w * random.uniform(1 - self.ratio, 1 + self.ratio)
-        size = (h, int(size_w))
-        img = transforms.functional.resized_crop(img, y, x, h, w, size)
-        return img
-
-
-class Transpose:
-    """Transposes the EMNIST image to the correct orientation."""
-
-    def __call__(self, image: Image) -> np.ndarray:
-        """Swaps axis."""
-        return np.array(image).swapaxes(0, 1)
-
-
-class Resize:
-    """Resizes a tensor to a specified width."""
-
-    def __init__(self, width: int = 952) -> None:
-        # The default is 952 because of the IAM dataset.
-        self.width = width
-
-    def __call__(self, image: Tensor) -> Tensor:
-        """Resize tensor in the last dimension."""
-        return F.interpolate(image, size=self.width, mode="nearest")
-
-
-class AddTokens:
-    """Adds start of sequence and end of sequence tokens to target tensor."""
-
-    def __init__(self, pad_token: str, eos_token: str, init_token: str = None) -> None:
-        self.init_token = init_token
-        self.pad_token = pad_token
-        self.eos_token = eos_token
-        if self.init_token is not None:
-            self.emnist_mapper = EmnistMapper(
-                init_token=self.init_token,
-                pad_token=self.pad_token,
-                eos_token=self.eos_token,
-            )
-        else:
-            self.emnist_mapper = EmnistMapper(
-                pad_token=self.pad_token, eos_token=self.eos_token,
-            )
-        self.pad_value = self.emnist_mapper(self.pad_token)
-        self.eos_value = self.emnist_mapper(self.eos_token)
-
-    def __call__(self, target: Tensor) -> Tensor:
-        """Adds a sos token to the begining and a eos token to the end of a target sequence."""
-        dtype, device = target.dtype, target.device
-
-        # Find the where padding starts.
-        pad_index = torch.nonzero(target == self.pad_value, as_tuple=False)[0].item()
-
-        target[pad_index] = self.eos_value
-
-        if self.init_token is not None:
-            self.sos_value = self.emnist_mapper(self.init_token)
-            sos = torch.tensor([self.sos_value], dtype=dtype, device=device)
-            target = torch.cat([sos, target], dim=0)
-
-        return target
-
-
-class ApplyContrast:
-    """Sets everything below a threshold to zero, i.e. increase contrast."""
-
-    def __init__(self, low: float = 0.0, high: float = 0.25) -> None:
-        self.low = low
-        self.high = high
-
-    def __call__(self, x: Tensor) -> Tensor:
-        """Apply mask binary mask to input tensor."""
-        mask = x > np.random.RandomState().uniform(low=self.low, high=self.high)
-        return x * mask
-
-
-class Unsqueeze:
-    """Add a dimension to the tensor."""
-
-    def __call__(self, x: Tensor) -> Tensor:
-        """Adds dim."""
-        return x.unsqueeze(0)
-
-
-class Squeeze:
-    """Removes the first dimension of a tensor."""
-
-    def __call__(self, x: Tensor) -> Tensor:
-        """Removes first dim."""
-        return x.squeeze(0)
-
-
+from text_recognizer.data.emnist import emnist_mapping
+
 
 class ToLower:
     """Converts target to lower case."""
@@ -155,29 +22,14 @@
 
 class ToCharcters:
     """Converts integers to characters."""
 
-    def __init__(
-        self, pad_token: str, eos_token: str, init_token: str = None, lower: bool = True
-    ) -> None:
-        self.init_token = init_token
-        self.pad_token = pad_token
-        self.eos_token = eos_token
-        if self.init_token is not None:
-            self.emnist_mapper = EmnistMapper(
-                init_token=self.init_token,
-                pad_token=self.pad_token,
-                eos_token=self.eos_token,
-                lower=lower,
-            )
-        else:
-            self.emnist_mapper = EmnistMapper(
-                pad_token=self.pad_token, eos_token=self.eos_token, lower=lower
-            )
+    def __init__(self) -> None:
+        self.mapping, _, _ = emnist_mapping()
 
     def __call__(self, y: Tensor) -> str:
         """Converts a Tensor to a str."""
         return (
-            "".join([self.emnist_mapper(int(i)) for i in y])
-            .strip("_")
+            "".join([self.mapping[int(i)] for i in y])
+            .strip("<p>")
            .replace(" ", "▁")
         )