From 6adcf85afc71a6f276370c86f32b36b15603c9f5 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Thu, 30 Sep 2021 23:59:22 +0200 Subject: Lint emnist lines --- text_recognizer/data/emnist_lines.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) (limited to 'text_recognizer/data') diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py index 5298726..d4b2b40 100644 --- a/text_recognizer/data/emnist_lines.py +++ b/text_recognizer/data/emnist_lines.py @@ -11,11 +11,11 @@ import torch from torchvision import transforms from torchvision.transforms.functional import InterpolationMode -from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels from text_recognizer.data.base_data_module import ( BaseDataModule, load_and_print_info, ) +from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels from text_recognizer.data.emnist import EMNIST from text_recognizer.data.sentence_generator import SentenceGenerator @@ -34,7 +34,7 @@ MAX_OUTPUT_LENGTH = 89 # Same as IAMLines @attr.s(auto_attribs=True, repr=False) class EMNISTLines(BaseDataModule): - """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST,""" + """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST.""" augment: bool = attr.ib(default=True) max_length: int = attr.ib(default=128) @@ -46,6 +46,7 @@ class EMNISTLines(BaseDataModule): emnist: EMNIST = attr.ib(init=False, default=None) def __attrs_post_init__(self) -> None: + """Post init constructor.""" self.emnist = EMNIST(mapping=self.mapping) max_width = ( @@ -77,6 +78,7 @@ class EMNISTLines(BaseDataModule): ) def prepare_data(self) -> None: + """Prepare the dataset.""" if self.data_filename.exists(): return np.random.seed(SEED) @@ -85,6 +87,7 @@ class EMNISTLines(BaseDataModule): self._generate_data("test") def setup(self, stage: str = None) -> None: + """Loads the dataset.""" log.info("EMNISTLinesDataset loading data from HDF5...") if stage == "fit" or stage is None: print(self.data_filename) @@ -260,5 +263,5 @@ def _get_transform(augment: bool = False) -> Callable: def generate_emnist_lines() -> None: - """Generates a synthetic handwritten dataset and displays info,""" + """Generates a synthetic handwritten dataset and displays info.""" load_and_print_info(EMNISTLines) -- cgit v1.2.3-70-g09d2