summary | refs | log | tree | commit | diff
path: root/text_recognizer/data/emnist_lines.py
diff options
context:
space:
mode:
author Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-07-06 17:42:53 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-07-06 17:42:53 +0200
commit eb5b206f7e1b08435378d2a02395307be55ee6f1 (patch)
tree 0cd30234afab698eb632b20a7da97e3bc7e98882 /text_recognizer/data/emnist_lines.py
parent 4d1f2cef39688871d2caafce42a09316381a27ae (diff)
Refactoring data with attrs and refactor conf for hydra
Diffstat (limited to 'text_recognizer/data/emnist_lines.py')
-rw-r--r-- text_recognizer/data/emnist_lines.py 35
1 files changed, 11 insertions, 24 deletions
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 9650198..4747508 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -3,6 +3,7 @@ from collections import defaultdict
from pathlib import Path
from typing import Callable, Dict, Tuple
+import attr
import h5py
from loguru import logger
import numpy as np
@@ -31,31 +32,20 @@ IMAGE_X_PADDING = 28
MAX_OUTPUT_LENGTH = 89 # Same as IAMLines
+@attr.s(auto_attribs=True)
class EMNISTLines(BaseDataModule):
"""EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST,"""
- def __init__(
- self,
- augment: bool = True,
- batch_size: int = 128,
- num_workers: int = 0,
- max_length: int = 32,
- min_overlap: float = 0.0,
- max_overlap: float = 0.33,
- num_train: int = 10_000,
- num_val: int = 2_000,
- num_test: int = 2_000,
- ) -> None:
- super().__init__(batch_size, num_workers)
-
- self.augment = augment
- self.max_length = max_length
- self.min_overlap = min_overlap
- self.max_overlap = max_overlap
- self.num_train = num_train
- self.num_val = num_val
- self.num_test = num_test
+ augment: bool = attr.ib(default=True)
+ max_length: int = attr.ib(default=128)
+ min_overlap: float = attr.ib(default=0.0)
+ max_overlap: float = attr.ib(default=0.33)
+ num_train: int = attr.ib(default=10_000)
+ num_val: int = attr.ib(default=2_000)
+ num_test: int = attr.ib(default=2_000)
+ emnist: EMNIST = attr.ib(init=False, default=None)
+ def __attrs_post_init__(self) -> None:
self.emnist = EMNIST()
self.mapping = self.emnist.mapping
@@ -75,9 +65,6 @@ class EMNISTLines(BaseDataModule):
raise ValueError("max_length greater than MAX_OUTPUT_LENGTH")
self.output_dims = (MAX_OUTPUT_LENGTH, 1)
- self.data_train: BaseDataset = None
- self.data_val: BaseDataset = None
- self.data_test: BaseDataset = None
@property
def data_filename(self) -> Path: