diff options
| -rw-r--r-- | text_recognizer/data/emnist_lines.py | 9 | 
1 files changed, 6 insertions, 3 deletions
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py index 5298726..d4b2b40 100644 --- a/text_recognizer/data/emnist_lines.py +++ b/text_recognizer/data/emnist_lines.py @@ -11,11 +11,11 @@ import torch  from torchvision import transforms  from torchvision.transforms.functional import InterpolationMode -from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels  from text_recognizer.data.base_data_module import (      BaseDataModule,      load_and_print_info,  ) +from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels  from text_recognizer.data.emnist import EMNIST  from text_recognizer.data.sentence_generator import SentenceGenerator @@ -34,7 +34,7 @@ MAX_OUTPUT_LENGTH = 89  # Same as IAMLines  @attr.s(auto_attribs=True, repr=False)  class EMNISTLines(BaseDataModule): -    """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST,""" +    """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST."""      augment: bool = attr.ib(default=True)      max_length: int = attr.ib(default=128) @@ -46,6 +46,7 @@ class EMNISTLines(BaseDataModule):      emnist: EMNIST = attr.ib(init=False, default=None)      def __attrs_post_init__(self) -> None: +        """Post init constructor."""          self.emnist = EMNIST(mapping=self.mapping)          max_width = ( @@ -77,6 +78,7 @@ class EMNISTLines(BaseDataModule):          )      def prepare_data(self) -> None: +        """Prepare the dataset."""          if self.data_filename.exists():              return          np.random.seed(SEED) @@ -85,6 +87,7 @@ class EMNISTLines(BaseDataModule):          self._generate_data("test")      def setup(self, stage: str = None) -> None: +        """Loads the dataset."""          log.info("EMNISTLinesDataset loading data from HDF5...")          if stage == "fit" or stage is None:              print(self.data_filename) @@ -260,5 +263,5 @@ def _get_transform(augment: bool = False) -> Callable:  def generate_emnist_lines() -> None: -    """Generates a synthetic handwritten dataset and displays info,""" +    """Generates a synthetic handwritten dataset and displays info."""      load_and_print_info(EMNISTLines)  |