diff options
| author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-08 23:14:23 +0200 | 
|---|---|---|
| committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-08 23:14:23 +0200 | 
| commit | e1b504bca41a9793ed7e88ef14f2e2cbd85724f2 (patch) | |
| tree | 70b482f890c9ad2be104f0bff8f2172e8411a2be /src/text_recognizer/datasets | |
| parent | fe23001b6588e6e6e9e2c5a99b72f3445cf5206f (diff) | |
IAM datasets implemented.
Diffstat (limited to 'src/text_recognizer/datasets')
| -rw-r--r-- | src/text_recognizer/datasets/__init__.py | 12 | ||||
| -rw-r--r-- | src/text_recognizer/datasets/emnist_dataset.py | 32 | ||||
| -rw-r--r-- | src/text_recognizer/datasets/emnist_lines_dataset.py | 38 | ||||
| -rw-r--r-- | src/text_recognizer/datasets/iam_dataset.py | 134 | ||||
| -rw-r--r-- | src/text_recognizer/datasets/iam_lines_dataset.py | 126 | ||||
| -rw-r--r-- | src/text_recognizer/datasets/iam_paragraphs_dataset.py | 322 | ||||
| -rw-r--r-- | src/text_recognizer/datasets/util.py | 99 | 
7 files changed, 685 insertions, 78 deletions
| diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py index 05f74f6..ede4541 100644 --- a/src/text_recognizer/datasets/__init__.py +++ b/src/text_recognizer/datasets/__init__.py @@ -10,15 +10,23 @@ from .emnist_lines_dataset import (      EmnistLinesDataset,      get_samples_by_character,  ) -from .util import fetch_data_loaders, Transpose +from .iam_dataset import IamDataset +from .iam_lines_dataset import IamLinesDataset +from .iam_paragraphs_dataset import IamParagraphsDataset +from .util import _download_raw_dataset, compute_sha256, download_url, Transpose  __all__ = [ +    "_download_raw_dataset", +    "compute_sha256",      "construct_image_from_string",      "DATA_DIRNAME", +    "download_url",      "EmnistDataset",      "EmnistMapper",      "EmnistLinesDataset", -    "fetch_data_loaders",      "get_samples_by_character", +    "IamDataset", +    "IamLinesDataset", +    "IamParagraphsDataset",      "Transpose",  ] diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py index 49ebad3..0715aae 100644 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ b/src/text_recognizer/datasets/emnist_dataset.py @@ -152,8 +152,7 @@ class EmnistDataset(Dataset):          """Loads the dataset and the mappings.          Args: -            train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to -                False. +            train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False.              sample_to_balance (bool): Resamples the dataset to make it balanced. Defaults to False.              subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None.              transform (Optional[Callable]): Transform(s) for input data. Defaults to None. @@ -181,17 +180,37 @@ class EmnistDataset(Dataset):          self.seed = seed          self._mapper = EmnistMapper() -        self.input_shape = self._mapper.input_shape +        self._input_shape = self._mapper.input_shape          self.num_classes = self._mapper.num_classes          # Load dataset. -        self.data, self.targets = self.load_emnist_dataset() +        self._data, self._targets = self.load_emnist_dataset() + +    @property +    def data(self) -> Tensor: +        """The input data.""" +        return self._data + +    @property +    def targets(self) -> Tensor: +        """The target data.""" +        return self._targets + +    @property +    def input_shape(self) -> Tuple: +        """Input shape of the data.""" +        return self._input_shape      @property      def mapper(self) -> EmnistMapper:          """Returns the EmnistMapper."""          return self._mapper +    @property +    def inverse_mapping(self) -> Dict: +        """Returns the inverse mapping from character to index.""" +        return self.mapper.inverse_mapping +      def __len__(self) -> int:          """Returns the length of the dataset."""          return len(self.data) @@ -220,11 +239,6 @@ class EmnistDataset(Dataset):          return data, targets -    @property -    def __name__(self) -> str: -        """Returns the name of the dataset.""" -        return "EmnistDataset" -      def __repr__(self) -> str:          """Returns information about the dataset."""          return ( diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py index b0617f5..656131a 100644 --- a/src/text_recognizer/datasets/emnist_lines_dataset.py +++ b/src/text_recognizer/datasets/emnist_lines_dataset.py @@ -9,8 +9,8 @@ from loguru import logger  import numpy as np  import torch  from torch import Tensor -from torch.utils.data import DataLoader, Dataset -from torchvision.transforms import Compose, Normalize, ToTensor +from torch.utils.data import Dataset +from torchvision.transforms import ToTensor  from text_recognizer.datasets import (      DATA_DIRNAME, @@ -20,6 +20,7 @@ from text_recognizer.datasets import (  )  from text_recognizer.datasets.sentence_generator import SentenceGenerator  from text_recognizer.datasets.util import Transpose +from text_recognizer.networks import sliding_window  DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines" @@ -55,7 +56,7 @@ class EmnistLinesDataset(Dataset):          self.transform = transform          if self.transform is None: -            self.transform = Compose([ToTensor()]) +            self.transform = ToTensor()          self.target_transform = target_transform          if self.target_transform is None: @@ -63,14 +64,14 @@ class EmnistLinesDataset(Dataset):          # Extract dataset information.          self._mapper = EmnistMapper() -        self.input_shape = self._mapper.input_shape +        self._input_shape = self._mapper.input_shape          self.num_classes = self._mapper.num_classes          self.max_length = max_length          self.min_overlap = min_overlap          self.max_overlap = max_overlap          self.num_samples = num_samples -        self.input_shape = ( +        self._input_shape = (              self.input_shape[0],              self.input_shape[1] * self.max_length,          ) @@ -84,6 +85,11 @@ class EmnistLinesDataset(Dataset):          # Load dataset.          self._load_or_generate_data() +    @property +    def input_shape(self) -> Tuple: +        """Input shape of the data.""" +        return self._input_shape +      def __len__(self) -> int:          """Returns the length of the dataset."""          return len(self.data) @@ -112,11 +118,6 @@ class EmnistLinesDataset(Dataset):          return data, targets -    @property -    def __name__(self) -> str: -        """Returns the name of the dataset.""" -        return "EmnistLinesDataset" -      def __repr__(self) -> str:          """Returns information about the dataset."""          return ( @@ -136,13 +137,18 @@ class EmnistLinesDataset(Dataset):          return self._mapper      @property +    def mapping(self) -> Dict: +        """Return EMNIST mapping from index to character.""" +        return self._mapper.mapping + +    @property      def data_filename(self) -> Path:          """Path to the h5 file."""          filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt"          if self.train:              filename = "train_" + filename          else: -            filename = "val_" + filename +            filename = "test_" + filename          return DATA_DIRNAME / filename      def _load_or_generate_data(self) -> None: @@ -184,7 +190,7 @@ class EmnistLinesDataset(Dataset):              )              targets = convert_strings_to_categorical_labels( -                targets, self.emnist.inverse_mapping +                targets, emnist.inverse_mapping              )              f.create_dataset("data", data=data, dtype="u1", compression="lzf") @@ -322,13 +328,13 @@ def create_datasets(      min_overlap: float = 0,      max_overlap: float = 0.33,      num_train: int = 10000, -    num_val: int = 1000, +    num_test: int = 1000,  ) -> None:      """Creates a training an validation dataset of Emnist lines."""      emnist_train = EmnistDataset(train=True, sample_to_balance=True) -    emnist_val = EmnistDataset(train=False, sample_to_balance=True) -    datasets = [emnist_train, emnist_val] -    num_samples = [num_train, num_val] +    emnist_test = EmnistDataset(train=False, sample_to_balance=True) +    datasets = [emnist_train, emnist_test] +    num_samples = [num_train, num_test]      for num, train, dataset in zip(num_samples, [True, False], datasets):          emnist_lines = EmnistLinesDataset(              train=train, diff --git a/src/text_recognizer/datasets/iam_dataset.py b/src/text_recognizer/datasets/iam_dataset.py new file mode 100644 index 0000000..5e47350 --- /dev/null +++ b/src/text_recognizer/datasets/iam_dataset.py @@ -0,0 +1,134 @@ +"""Class for loading the IAM dataset, which encompasses both paragraphs and lines, with associated utilities.""" +import os +from typing import Any, Dict, List +import zipfile + +from boltons.cacheutils import cachedproperty +import defusedxml.ElementTree as ET +from loguru import logger +import toml +from torch.utils.data import Dataset + +from text_recognizer.datasets import DATA_DIRNAME +from text_recognizer.datasets.util import _download_raw_dataset + +RAW_DATA_DIRNAME = DATA_DIRNAME / "raw" / "iam" +METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" +EXTRACTED_DATASET_DIRNAME = RAW_DATA_DIRNAME / "iamdb" + +DOWNSAMPLE_FACTOR = 2  # If images were downsampled, the regions must also be. +LINE_REGION_PADDING = 0  # Add this many pixels around the exact coordinates. + + +class IamDataset(Dataset): +    """IAM dataset. + +    "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text, +    which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels." +    From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database + +    The data split we will use is +    IAM lines Large Writer Independent Text Line Recognition Task (lwitlrt): 9,862 text lines. +        The validation set has been merged into the train set. +        The train set has 7,101 lines from 326 writers. +        The test set has 1,861 lines from 128 writers. +        The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only. + +    """ + +    def __init__(self) -> None: +        self.metadata = toml.load(METADATA_FILENAME) + +    def load_or_generate_data(self) -> None: +        """Downloads IAM dataset if xml files does not exist.""" +        if not self.xml_filenames: +            self._download_iam() + +    @property +    def xml_filenames(self) -> List: +        """List of xml filenames.""" +        return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml")) + +    @property +    def form_filenames(self) -> List: +        """List of forms filenames.""" +        return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg")) + +    def _download_iam(self) -> None: +        curdir = os.getcwd() +        os.chdir(RAW_DATA_DIRNAME) +        _download_raw_dataset(self.metadata) +        _extract_raw_dataset(self.metadata) +        os.chdir(curdir) + +    @property +    def form_filenames_by_id(self) -> Dict: +        """Creates a dictionary with filenames as keys and forms as values.""" +        return {filename.stem: filename for filename in self.form_filenames} + +    @cachedproperty +    def line_strings_by_id(self) -> Dict: +        """Return a dict from name of IAM form to a list of line texts in it.""" +        return { +            filename.stem: _get_line_strings_from_xml_file(filename) +            for filename in self.xml_filenames +        } + +    @cachedproperty +    def line_regions_by_id(self) -> Dict: +        """Return a dict from name of IAM form to a list of (x1, x2, y1, y2) coordinates of all lines in it.""" +        return { +            filename.stem: _get_line_regions_from_xml_file(filename) +            for filename in self.xml_filenames +        } + +    def __repr__(self) -> str: +        """Print info about dataset.""" +        return "IAM Dataset\n" f"Number of forms: {len(self.xml_filenames)}\n" + + +def _extract_raw_dataset(metadata: Dict) -> None: +    logger.info("Extracting IAM data.") +    with zipfile.ZipFile(metadata["filename"], "r") as zip_file: +        zip_file.extractall() + + +def _get_line_strings_from_xml_file(filename: str) -> List[str]: +    """Get the text content of each line. Note that we replace " with ".""" +    xml_root_element = ET.parse(filename).getroot()  # nosec +    xml_line_elements = xml_root_element.findall("handwritten-part/line") +    return [el.attrib["text"].replace(""", '"') for el in xml_line_elements] + + +def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]: +    """Get the line region dict for each line.""" +    xml_root_element = ET.parse(filename).getroot()  # nosec +    xml_line_elements = xml_root_element.findall("handwritten-part/line") +    return [_get_line_region_from_xml_element(el) for el in xml_line_elements] + + +def _get_line_region_from_xml_element(xml_line: Any) -> Dict[str, int]: +    """Extracts coordinates for each line of text.""" +    # TODO: fix input! +    word_elements = xml_line.findall("word/cmp") +    x1s = [int(el.attrib["x"]) for el in word_elements] +    y1s = [int(el.attrib["y"]) for el in word_elements] +    x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements] +    y2s = [int(el.attrib["y"]) + int(el.attrib["height"]) for el in word_elements] +    return { +        "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, +        "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, +        "x2": max(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, +        "y2": max(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, +    } + + +def main() -> None: +    """Initializes the dataset and print info about the dataset.""" +    dataset = IamDataset() +    dataset.load_or_generate_data() +    print(dataset) + + +if __name__ == "__main__": +    main() diff --git a/src/text_recognizer/datasets/iam_lines_dataset.py b/src/text_recognizer/datasets/iam_lines_dataset.py new file mode 100644 index 0000000..477f500 --- /dev/null +++ b/src/text_recognizer/datasets/iam_lines_dataset.py @@ -0,0 +1,126 @@ +"""IamLinesDataset class.""" +from typing import Callable, Dict, List, Optional, Tuple, Union + +import h5py +from loguru import logger +import torch +from torch import Tensor +from torch.utils.data import Dataset +from torchvision.transforms import ToTensor + +from text_recognizer.datasets.emnist_dataset import DATA_DIRNAME, EmnistMapper +from text_recognizer.datasets.util import compute_sha256, download_url + + +PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_lines" +PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "iam_lines.h5" +PROCESSED_DATA_URL = ( +    "https://s3-us-west-2.amazonaws.com/fsdl-public-assets/iam_lines.h5" +) + + +class IamLinesDataset(Dataset): +    """IAM lines datasets for handwritten text lines.""" + +    def __init__( +        self, +        train: bool = False, +        subsample_fraction: float = None, +        transform: Optional[Callable] = None, +        target_transform: Optional[Callable] = None, +    ) -> None: +        self.train = train +        self.split = "train" if self.train else "test" +        self._mapper = EmnistMapper() +        self.num_classes = self.mapper.num_classes + +        # Set transforms. +        self.transform = transform +        if self.transform is None: +            self.transform = ToTensor() + +        self.target_transform = target_transform +        if self.target_transform is None: +            self.target_transform = torch.tensor + +        self.subsample_fraction = subsample_fraction +        self.data = None +        self.targets = None + +    @property +    def mapper(self) -> EmnistMapper: +        """Returns the EmnistMapper.""" +        return self._mapper + +    @property +    def mapping(self) -> Dict: +        """Return EMNIST mapping from index to character.""" +        return self._mapper.mapping + +    @property +    def input_shape(self) -> Tuple: +        """Input shape of the data.""" +        return self.data.shape[1:] + +    @property +    def output_shape(self) -> Tuple: +        """Output shape of the data.""" +        return self.targets.shape[1:] + (self.num_classes,) + +    def __len__(self) -> int: +        """Returns the length of the dataset.""" +        return len(self.data) + +    def load_or_generate_data(self) -> None: +        """Load or generate dataset data.""" +        if not PROCESSED_DATA_FILENAME.exists(): +            PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) +            logger.info("Downloading IAM lines...") +            download_url(PROCESSED_DATA_URL, PROCESSED_DATA_FILENAME) +        with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: +            self.data = f[f"x_{self.split}"][:] +            self.targets = f[f"y_{self.split}"][:] +        self._subsample() + +    def _subsample(self) -> None: +        """Only a fraction of the data will be loaded.""" +        if self.subsample_fraction is None: +            return + +        num_samples = int(self.data.shape[0] * self.subsample_fraction) +        self.data = self.data[:num_samples] +        self.targets = self.targets[:num_samples] + +    def __repr__(self) -> str: +        """Print info about the dataset.""" +        return ( +            "IAM Lines Dataset\n"  # pylint: disable=no-member +            f"Number classes: {self.num_classes}\n" +            f"Mapping: {self.mapper.mapping}\n" +            f"Data: {self.data.shape}\n" +            f"Targets: {self.targets.shape}\n" +        ) + +    def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]: +        """Fetches data, target pair of the dataset for a given and index or indices. + +        Args: +            index (Union[int, Tensor]): Either a list or int of indices/index. + +        Returns: +            Tuple[Tensor, Tensor]: Data target pair. + +        """ +        if torch.is_tensor(index): +            index = index.tolist() + +        data = self.data[index] +        targets = self.targets[index] + +        if self.transform: +            data = self.transform(data) + +        if self.target_transform: +            targets = self.target_transform(targets) + +        return data, targets diff --git a/src/text_recognizer/datasets/iam_paragraphs_dataset.py b/src/text_recognizer/datasets/iam_paragraphs_dataset.py new file mode 100644 index 0000000..d65b346 --- /dev/null +++ b/src/text_recognizer/datasets/iam_paragraphs_dataset.py @@ -0,0 +1,322 @@ +"""IamParagraphsDataset class and functions for data processing.""" +from typing import Callable, Dict, List, Optional, Tuple, Union + +import click +import cv2 +import h5py +from loguru import logger +import numpy as np +import torch +from torch import Tensor +from torch.utils.data import Dataset +from torchvision.transforms import ToTensor + +from text_recognizer import util +from text_recognizer.datasets.emnist_dataset import DATA_DIRNAME, EmnistMapper +from text_recognizer.datasets.iam_dataset import IamDataset +from text_recognizer.datasets.util import compute_sha256, download_url + +INTERIM_DATA_DIRNAME = DATA_DIRNAME / "interim" / "iam_paragraphs" +DEBUG_CROPS_DIRNAME = INTERIM_DATA_DIRNAME / "debug_crops" +PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_paragraphs" +CROPS_DIRNAME = PROCESSED_DATA_DIRNAME / "crops" +GT_DIRNAME = PROCESSED_DATA_DIRNAME / "gt" + +PARAGRAPH_BUFFER = 50  # Pixels in the IAM form images to leave around the lines. +TEST_FRACTION = 0.2 +SEED = 4711 + + +class IamParagraphsDataset(Dataset): +    """IAM Paragraphs dataset for paragraphs of handwritten text. + +    TODO: __getitem__, __len__, get_data_target_from_id + +    """ + +    def __init__( +        self, +        train: bool = False, +        subsample_fraction: float = None, +        transform: Optional[Callable] = None, +        target_transform: Optional[Callable] = None, +    ) -> None: + +        # Load Iam dataset. +        self.iam_dataset = IamDataset() + +        self.train = train +        self.split = "train" if self.train else "test" +        self.num_classes = 3 +        self._input_shape = (256, 256) +        self._output_shape = self._input_shape + (self.num_classes,) +        self.subsample_fraction = subsample_fraction + +        # Set transforms. +        self.transform = transform +        if self.transform is None: +            self.transform = ToTensor() + +        self.target_transform = target_transform +        if self.target_transform is None: +            self.target_transform = torch.tensor + +        self._data = None +        self._targets = None +        self._ids = None + +    def __len__(self) -> int: +        """Returns the length of the dataset.""" +        return len(self.data) + +    def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]: +        """Fetches data, target pair of the dataset for a given and index or indices. + +        Args: +            index (Union[int, Tensor]): Either a list or int of indices/index. + +        Returns: +            Tuple[Tensor, Tensor]: Data target pair. + +        """ +        if torch.is_tensor(index): +            index = index.tolist() + +        data = self.data[index] +        targets = self.targets[index] + +        if self.transform: +            data = self.transform(data) + +        if self.target_transform: +            targets = self.target_transform(targets) + +        return data, targets + +    @property +    def input_shape(self) -> Tuple: +        """Input shape of the data.""" +        return self._input_shape + +    @property +    def output_shape(self) -> Tuple: +        """Output shape of the data.""" +        return self._output_shape + +    @property +    def data(self) -> Tensor: +        """The input data.""" +        return self._data + +    @property +    def targets(self) -> Tensor: +        """The target data.""" +        return self._targets + +    @property +    def ids(self) -> Tensor: +        """Ids of the dataset.""" +        return self._ids + +    def get_data_and_target_from_id(self, id_: str) -> Tuple[Tensor, Tensor]: +        """Get data target pair from id.""" +        ind = self.ids.index(id_) +        return self.data[ind], self.targets[ind] + +    def load_or_generate_data(self) -> None: +        """Load or generate dataset data.""" +        num_actual = len(list(CROPS_DIRNAME.glob("*.jpg"))) +        num_targets = len(self.iam_dataset.line_regions_by_id) + +        if num_actual < num_targets - 2: +            self._process_iam_paragraphs() + +        self._data, self._targets, self._ids = _load_iam_paragraphs() +        self._get_random_split() +        self._subsample() + +    def _get_random_split(self) -> None: +        np.random.seed(SEED) +        num_train = int((1 - TEST_FRACTION) * self.data.shape[0]) +        indices = np.random.permutation(self.data.shape[0]) +        train_indices, test_indices = indices[:num_train], indices[num_train:] +        if self.train: +            self._data = self.data[train_indices] +            self._targets = self.targets[train_indices] +        else: +            self._data = self.data[test_indices] +            self._targets = self.targets[test_indices] + +    def _process_iam_paragraphs(self) -> None: +        """Crop the part with the text. + +        For each page, crop out the part of it that correspond to the paragraph of text, and make sure all crops are +        self.input_shape. The ground truth data is the same size, with a one-hot vector at each pixel +        corresponding to labels 0=background, 1=odd-numbered line, 2=even-numbered line +        """ +        crop_dims = self._decide_on_crop_dims() +        CROPS_DIRNAME.mkdir(parents=True, exist_ok=True) +        DEBUG_CROPS_DIRNAME.mkdir(parents=True, exist_ok=True) +        GT_DIRNAME.mkdir(parents=True, exist_ok=True) +        logger.info( +            f"Cropping paragraphs, generating ground truth, and saving debugging images to {DEBUG_CROPS_DIRNAME}" +        ) +        for filename in self.iam_dataset.form_filenames: +            id_ = filename.stem +            line_region = self.iam_dataset.line_regions_by_id[id_] +            _crop_paragraph_image(filename, line_region, crop_dims, self.input_shape) + +    def _decide_on_crop_dims(self) -> Tuple[int, int]: +        """Decide on the dimensions to crop out of the form image. + +        Since image width is larger than a comfortable crop around the longest paragraph, +        we will make the crop a square form factor. +        And since the found dimensions 610x610 are pretty close to 512x512, +        we might as well resize crops and make it exactly that, which lets us +        do all kinds of power-of-2 pooling and upsampling should we choose to. + +        Returns: +            Tuple[int, int]: A tuple of crop dimensions. + +        Raises: +            RuntimeError: When max crop height is larger than max crop width. + +        """ + +        sample_form_filename = self.iam_dataset.form_filenames[0] +        sample_image = util.read_image(sample_form_filename, grayscale=True) +        max_crop_width = sample_image.shape[1] +        max_crop_height = _get_max_paragraph_crop_height( +            self.iam_dataset.line_regions_by_id +        ) +        if not max_crop_height <= max_crop_width: +            raise RuntimeError( +                f"Max crop height is larger then max crop width: {max_crop_height} >= {max_crop_width}" +            ) + +        crop_dims = (max_crop_width, max_crop_width) +        logger.info( +            f"Max crop width and height were found to be {max_crop_width}x{max_crop_height}." +        ) +        logger.info(f"Setting them to {max_crop_width}x{max_crop_width}") +        return crop_dims + +    def _subsample(self) -> None: +        """Only this fraction of the data will be loaded.""" +        if self.subsample_fraction is None: +            return +        num_subsample = int(self.data.shape[0] * self.subsample_fraction) +        self.data = self.data[:num_subsample] +        self.targets = self.targets[:num_subsample] + +    def __repr__(self) -> str: +        """Return info about the dataset.""" +        return ( +            "IAM Paragraph Dataset\n"  # pylint: disable=no-member +            f"Num classes: {self.num_classes}\n" +            f"Data: {self.data.shape}\n" +            f"Targets: {self.targets.shape}\n" +        ) + + +def _get_max_paragraph_crop_height(line_regions_by_id: Dict) -> int: +    heights = [] +    for regions in line_regions_by_id.values(): +        min_y1 = min(region["y1"] for region in regions) - PARAGRAPH_BUFFER +        max_y2 = max(region["y2"] for region in regions) + PARAGRAPH_BUFFER +        height = max_y2 - min_y1 +        heights.append(height) +    return max(heights) + + +def _crop_paragraph_image( +    filename: str, line_regions: Dict, crop_dims: Tuple[int, int], final_dims: Tuple +) -> None: +    image = util.read_image(filename, grayscale=True) + +    min_y1 = min(region["y1"] for region in line_regions) - PARAGRAPH_BUFFER +    max_y2 = max(region["y2"] for region in line_regions) + PARAGRAPH_BUFFER +    height = max_y2 - min_y1 +    crop_height = crop_dims[0] +    buffer = (crop_height - height) // 2 + +    # Generate image crop. +    image_crop = 255 * np.ones(crop_dims, dtype=np.uint8) +    try: +        image_crop[buffer : buffer + height] = image[min_y1:max_y2] +    except Exception as e:  # pylint: disable=broad-except +        logger.error(f"Rescued {filename}: {e}") +        return + +    # Generate ground truth. +    gt_image = np.zeros_like(image_crop, dtype=np.uint8) +    for index, region in enumerate(line_regions): +        gt_image[ +            (region["y1"] - min_y1 + buffer) : (region["y2"] - min_y1 + buffer), +            region["x1"] : region["x2"], +        ] = (index % 2 + 1) + +    # Generate image for debugging. +    import matplotlib.pyplot as plt + +    cmap = plt.get_cmap("Set1") +    image_crop_for_debug = np.dstack([image_crop, image_crop, image_crop]) +    for index, region in enumerate(line_regions): +        color = [255 * _ for _ in cmap(index)[:-1]] +        cv2.rectangle( +            image_crop_for_debug, +            (region["x1"], region["y1"] - min_y1 + buffer), +            (region["x2"], region["y2"] - min_y1 + buffer), +            color, +            3, +        ) +    image_crop_for_debug = cv2.resize( +        image_crop_for_debug, final_dims, interpolation=cv2.INTER_AREA +    ) +    util.write_image(image_crop_for_debug, DEBUG_CROPS_DIRNAME / f"{filename.stem}.jpg") + +    image_crop = cv2.resize(image_crop, final_dims, interpolation=cv2.INTER_AREA) +    util.write_image(image_crop, CROPS_DIRNAME / f"{filename.stem}.jpg") + +    gt_image = cv2.resize(gt_image, final_dims, interpolation=cv2.INTER_NEAREST) +    util.write_image(gt_image, GT_DIRNAME / f"{filename.stem}.png") + + +def _load_iam_paragraphs() -> None: +    logger.info("Loading IAM paragraph crops and ground truth from image files...") +    images = [] +    gt_images = [] +    ids = [] +    for filename in CROPS_DIRNAME.glob("*.jpg"): +        id_ = filename.stem +        image = util.read_image(filename, grayscale=True) +        image = 1.0 - image / 255 + +        gt_filename = GT_DIRNAME / f"{id_}.png" +        gt_image = util.read_image(gt_filename, grayscale=True) + +        images.append(image) +        gt_images.append(gt_image) +        ids.append(id_) +    images = np.array(images).astype(np.float32) +    gt_images = np.array(gt_images).astype(np.uint8) +    ids = np.array(ids) +    return images, gt_images, ids + + +@click.command() +@click.option( +    "--subsample_fraction", +    type=float, +    default=0.0, +    help="The subsampling factor of the dataset.", +) +def main(subsample_fraction: float) -> None: +    """Load dataset and print info.""" +    dataset = IamParagraphsDataset(subsample_fraction=subsample_fraction) +    dataset.load_or_generate_data() +    print(dataset) + + +if __name__ == "__main__": +    main() diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py index 76bd85f..dd16bed 100644 --- a/src/text_recognizer/datasets/util.py +++ b/src/text_recognizer/datasets/util.py @@ -1,10 +1,17 @@  """Util functions for datasets.""" +import hashlib  import importlib -from typing import Callable, Dict, List, Type +import os +from pathlib import Path +from typing import Callable, Dict, List, Optional, Type, Union +from urllib.request import urlopen, urlretrieve +import cv2 +from loguru import logger  import numpy as np  from PIL import Image  from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm  class Transpose: @@ -15,58 +22,48 @@ class Transpose:          return np.array(image).swapaxes(0, 1) -def fetch_data_loaders( -    splits: List[str], -    dataset: str, -    dataset_args: Dict, -    batch_size: int = 128, -    shuffle: bool = False, -    num_workers: int = 0, -    cuda: bool = True, -) -> Dict[str, DataLoader]: -    """Fetches DataLoaders for given split(s) as a dictionary. - -    Loads the dataset class given, and loads it with the dataset arguments, for the number of splits specified. Then -    calls the DataLoader. Added to a dictionary with the split as key and DataLoader as value. - -    Args: -        splits (List[str]): One or both of the dataset splits "train" and "val". -        dataset (str): The name of the dataset. -        dataset_args (Dict): The dataset arguments. -        batch_size (int): How many samples per batch to load. Defaults to 128. -        shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False. -        num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be -            loaded in the main process. Defaults to 0. -        cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning -            them. Defaults to True. - -    Returns: -        Dict[str, DataLoader]: Dictionary with split as key and PyTorch DataLoader as value. +def compute_sha256(filename: Union[Path, str]) -> str: +    """Returns the SHA256 checksum of a file.""" +    with open(filename, "rb") as f: +        return hashlib.sha256(f.read()).hexdigest() -    """ - -    def check_dataset_args(args: Dict, split: str) -> Dict: -        """Adds train flag to the dataset args.""" -        args["train"] = True if split == "train" else False -        return args - -    # Import dataset module. -    datasets_module = importlib.import_module("text_recognizer.datasets") -    dataset_ = getattr(datasets_module, dataset) -    data_loaders = {} +class TqdmUpTo(tqdm): +    """TQDM progress bar when downloading files. -    for split in ["train", "val"]: -        if split in splits: +    From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py -            data_loader = DataLoader( -                dataset=dataset_(**check_dataset_args(dataset_args, split)), -                batch_size=batch_size, -                shuffle=shuffle, -                num_workers=num_workers, -                pin_memory=cuda, -            ) - -            data_loaders[split] = data_loader +    """ -    return data_loaders +    def update_to( +        self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None +    ) -> None: +        """Updates the progress bar. + +        Args: +            blocks (int): Number of blocks transferred so far. Defaults to 1. +            block_size (int): Size of each block, in tqdm units. Defaults to 1. +            total_size (Optional[int]): Total size in tqdm units. Defaults to None. +        """ +        if total_size is not None: +            self.total = total_size  # pylint: disable=attribute-defined-outside-init +        self.update(blocks * block_size - self.n) + + +def download_url(url: str, filename: str) -> None: +    """Downloads a file from url to filename, with a progress bar.""" +    with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: +        urlretrieve(url, filename, reporthook=t.update_to, data=None)  # nosec + + +def _download_raw_dataset(metadata: Dict) -> None: +    if os.path.exists(metadata["filename"]): +        return +    logger.info(f"Downloading raw dataset from {metadata['url']}...") +    download_url(metadata["url"], metadata["filename"]) +    logger.info("Computing SHA-256...") +    sha256 = compute_sha256(metadata["filename"]) +    if sha256 != metadata["sha256"]: +        raise ValueError( +            "Downloaded data file SHA-256 does not match that listed in metadata document." +        ) |