summaryrefslogtreecommitdiff
path: root/text_recognizer/data
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-09-30 23:59:22 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-09-30 23:59:22 +0200
commit6adcf85afc71a6f276370c86f32b36b15603c9f5 (patch)
tree5770ffeb44a0925c9b4e2bdee5ead85d6a052837 /text_recognizer/data
parent7275523f225703e1e4e3b28582703150afc9af29 (diff)
Lint emnist lines
Diffstat (limited to 'text_recognizer/data')
-rw-r--r--text_recognizer/data/emnist_lines.py9
1 files changed, 6 insertions, 3 deletions
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 5298726..d4b2b40 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -11,11 +11,11 @@ import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
-from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels
from text_recognizer.data.base_data_module import (
BaseDataModule,
load_and_print_info,
)
+from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels
from text_recognizer.data.emnist import EMNIST
from text_recognizer.data.sentence_generator import SentenceGenerator
@@ -34,7 +34,7 @@ MAX_OUTPUT_LENGTH = 89 # Same as IAMLines
@attr.s(auto_attribs=True, repr=False)
class EMNISTLines(BaseDataModule):
- """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST,"""
+ """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST."""
augment: bool = attr.ib(default=True)
max_length: int = attr.ib(default=128)
@@ -46,6 +46,7 @@ class EMNISTLines(BaseDataModule):
emnist: EMNIST = attr.ib(init=False, default=None)
def __attrs_post_init__(self) -> None:
+ """Post init constructor."""
self.emnist = EMNIST(mapping=self.mapping)
max_width = (
@@ -77,6 +78,7 @@ class EMNISTLines(BaseDataModule):
)
def prepare_data(self) -> None:
+ """Prepare the dataset."""
if self.data_filename.exists():
return
np.random.seed(SEED)
@@ -85,6 +87,7 @@ class EMNISTLines(BaseDataModule):
self._generate_data("test")
def setup(self, stage: str = None) -> None:
+ """Loads the dataset."""
log.info("EMNISTLinesDataset loading data from HDF5...")
if stage == "fit" or stage is None:
print(self.data_filename)
@@ -260,5 +263,5 @@ def _get_transform(augment: bool = False) -> Callable:
def generate_emnist_lines() -> None:
- """Generates a synthetic handwritten dataset and displays info,"""
+ """Generates a synthetic handwritten dataset and displays info."""
load_and_print_info(EMNISTLines)