diff options
| author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-06 17:42:53 +0200 | 
|---|---|---|
| committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-06 17:42:53 +0200 | 
| commit | eb5b206f7e1b08435378d2a02395307be55ee6f1 (patch) | |
| tree | 0cd30234afab698eb632b20a7da97e3bc7e98882 /text_recognizer/data/iam_lines.py | |
| parent | 4d1f2cef39688871d2caafce42a09316381a27ae (diff) | |
Refactoring data with attrs and refactor conf for hydra
Diffstat (limited to 'text_recognizer/data/iam_lines.py')
| -rw-r--r-- | text_recognizer/data/iam_lines.py | 22 | 
1 file changed, 7 insertions, 15 deletions
```diff
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index 9c78a22..e45e5c8 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -7,8 +7,9 @@ dataset.
 import json
 from pathlib import Path
 import random
-from typing import List, Sequence, Tuple
+from typing import Dict, List, Sequence, Tuple
 
+import attr
 from loguru import logger
 from PIL import Image, ImageFile, ImageOps
 import numpy as np
@@ -35,26 +36,17 @@ IMAGE_HEIGHT = 56
 IMAGE_WIDTH = 1024
 
 
+@attr.s(auto_attribs=True)
 class IAMLines(BaseDataModule):
     """IAM handwritten lines dataset."""
 
-    def __init__(
-        self,
-        augment: bool = True,
-        fraction: float = 0.8,
-        batch_size: int = 128,
-        num_workers: int = 0,
-    ) -> None:
-        # TODO: add transforms
-        super().__init__(batch_size, num_workers)
-        self.augment = augment
-        self.fraction = fraction
+    augment: bool = attr.ib(default=True)
+    fraction: float = attr.ib(default=0.8)
+
+    def __attrs_post_init__(self) -> None:
         self.mapping, self.inverse_mapping, _ = emnist_mapping()
         self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
         self.output_dims = (89, 1)
-        self.data_train: BaseDataset = None
-        self.data_val: BaseDataset = None
-        self.data_test: BaseDataset = None
 
     def prepare_data(self) -> None:
         """Creates the IAM lines dataset if not existing."""
```