diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-09-30 23:59:22 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-09-30 23:59:22 +0200 |
commit | 6adcf85afc71a6f276370c86f32b36b15603c9f5 (patch) | |
tree | 5770ffeb44a0925c9b4e2bdee5ead85d6a052837 /text_recognizer/data | |
parent | 7275523f225703e1e4e3b28582703150afc9af29 (diff) |
Lint emnist lines
Diffstat (limited to 'text_recognizer/data')
-rw-r--r-- | text_recognizer/data/emnist_lines.py | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py index 5298726..d4b2b40 100644 --- a/text_recognizer/data/emnist_lines.py +++ b/text_recognizer/data/emnist_lines.py @@ -11,11 +11,11 @@ import torch from torchvision import transforms from torchvision.transforms.functional import InterpolationMode -from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels from text_recognizer.data.base_data_module import ( BaseDataModule, load_and_print_info, ) +from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels from text_recognizer.data.emnist import EMNIST from text_recognizer.data.sentence_generator import SentenceGenerator @@ -34,7 +34,7 @@ MAX_OUTPUT_LENGTH = 89 # Same as IAMLines @attr.s(auto_attribs=True, repr=False) class EMNISTLines(BaseDataModule): - """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST,""" + """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST.""" augment: bool = attr.ib(default=True) max_length: int = attr.ib(default=128) @@ -46,6 +46,7 @@ class EMNISTLines(BaseDataModule): emnist: EMNIST = attr.ib(init=False, default=None) def __attrs_post_init__(self) -> None: + """Post init constructor.""" self.emnist = EMNIST(mapping=self.mapping) max_width = ( @@ -77,6 +78,7 @@ class EMNISTLines(BaseDataModule): ) def prepare_data(self) -> None: + """Prepare the dataset.""" if self.data_filename.exists(): return np.random.seed(SEED) @@ -85,6 +87,7 @@ class EMNISTLines(BaseDataModule): self._generate_data("test") def setup(self, stage: str = None) -> None: + """Loads the dataset.""" log.info("EMNISTLinesDataset loading data from HDF5...") if stage == "fit" or stage is None: print(self.data_filename) @@ -260,5 +263,5 @@ def _get_transform(augment: bool = False) -> Callable: def generate_emnist_lines() -> None: - """Generates a synthetic handwritten dataset and displays info,""" + """Generates a synthetic handwritten dataset and displays info.""" load_and_print_info(EMNISTLines) |