summary | refs | log | tree | commit | diff
path: root/text_recognizer/data/iam_lines.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/iam_lines.py')
-rw-r--r-- text_recognizer/data/iam_lines.py 22
1 files changed, 7 insertions, 15 deletions
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index 9c78a22..e45e5c8 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -7,8 +7,9 @@ dataset.
import json
from pathlib import Path
import random
-from typing import List, Sequence, Tuple
+from typing import Dict, List, Sequence, Tuple
+import attr
from loguru import logger
from PIL import Image, ImageFile, ImageOps
import numpy as np
@@ -35,26 +36,17 @@ IMAGE_HEIGHT = 56
IMAGE_WIDTH = 1024
+@attr.s(auto_attribs=True)
class IAMLines(BaseDataModule):
"""IAM handwritten lines dataset."""
- def __init__(
- self,
- augment: bool = True,
- fraction: float = 0.8,
- batch_size: int = 128,
- num_workers: int = 0,
- ) -> None:
- # TODO: add transforms
- super().__init__(batch_size, num_workers)
- self.augment = augment
- self.fraction = fraction
+ augment: bool = attr.ib(default=True)
+ fraction: float = attr.ib(default=0.8)
+
+ def __attrs_post_init__(self) -> None:
self.mapping, self.inverse_mapping, _ = emnist_mapping()
self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
self.output_dims = (89, 1)
- self.data_train: BaseDataset = None
- self.data_val: BaseDataset = None
- self.data_test: BaseDataset = None
def prepare_data(self) -> None:
"""Creates the IAM lines dataset if not existing."""