diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-02 21:13:48 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-02 21:13:48 +0200 |
commit | 75801019981492eedf9280cb352eea3d8e99b65f (patch) | |
tree | 6521cc4134459e42591b2375f70acd348741474e /text_recognizer/data/iam_lines.py | |
parent | e5eca28438cd17d436359f2c6eee0bb9e55d2a8b (diff) |
Fix log import, fix mapping in datamodules, fix nn modules so they can be hashed
Diffstat (limited to 'text_recognizer/data/iam_lines.py')
-rw-r--r-- | text_recognizer/data/iam_lines.py | 21 |
1 file changed, 8 insertions, 13 deletions
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py index b7f3fdd..1c63729 100644 --- a/text_recognizer/data/iam_lines.py +++ b/text_recognizer/data/iam_lines.py @@ -2,15 +2,14 @@ If not created, will generate a handwritten lines dataset from the IAM paragraphs dataset. - """ import json from pathlib import Path import random -from typing import Dict, List, Sequence, Tuple +from typing import List, Sequence, Tuple import attr -from loguru import logger +from loguru import logger as log from PIL import Image, ImageFile, ImageOps import numpy as np from torch import Tensor @@ -23,7 +22,7 @@ from text_recognizer.data.base_dataset import ( split_dataset, ) from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info -from text_recognizer.data.emnist import emnist_mapping +from text_recognizer.data.mappings import EmnistMapping from text_recognizer.data.iam import IAM from text_recognizer.data import image_utils @@ -48,17 +47,13 @@ class IAMLines(BaseDataModule): ) output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1)) - def __attrs_post_init__(self) -> None: - # TODO: refactor this - self.mapping, self.inverse_mapping, _ = emnist_mapping() - def prepare_data(self) -> None: """Creates the IAM lines dataset if not existing.""" if PROCESSED_DATA_DIRNAME.exists(): return - logger.info("Cropping IAM lines regions...") - iam = IAM() + log.info("Cropping IAM lines regions...") + iam = IAM(mapping=EmnistMapping()) iam.prepare_data() crops_train, labels_train = line_crops_and_labels(iam, "train") crops_test, labels_test = line_crops_and_labels(iam, "test") @@ -66,7 +61,7 @@ class IAMLines(BaseDataModule): shapes = np.array([crop.size for crop in crops_train + crops_test]) aspect_ratios = shapes[:, 0] / shapes[:, 1] - logger.info("Saving images, labels, and statistics...") + log.info("Saving images, labels, and statistics...") save_images_and_labels( crops_train, labels_train, "train", 
PROCESSED_DATA_DIRNAME ) @@ -91,7 +86,7 @@ class IAMLines(BaseDataModule): raise ValueError("Target length longer than max output length.") y_train = convert_strings_to_labels( - labels_train, self.inverse_mapping, length=self.output_dims[0] + labels_train, self.mapping.inverse_mapping, length=self.output_dims[0] ) data_train = BaseDataset( x_train, y_train, transform=get_transform(IMAGE_WIDTH, self.augment) @@ -110,7 +105,7 @@ class IAMLines(BaseDataModule): raise ValueError("Taget length longer than max output length.") y_test = convert_strings_to_labels( - labels_test, self.inverse_mapping, length=self.output_dims[0] + labels_test, self.mapping.inverse_mapping, length=self.output_dims[0] ) self.data_test = BaseDataset( x_test, y_test, transform=get_transform(IMAGE_WIDTH) |