summaryrefslogtreecommitdiff
path: root/text_recognizer/data/iam_extended_paragraphs.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-06 17:42:53 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-06 17:42:53 +0200
commiteb5b206f7e1b08435378d2a02395307be55ee6f1 (patch)
tree0cd30234afab698eb632b20a7da97e3bc7e98882 /text_recognizer/data/iam_extended_paragraphs.py
parent4d1f2cef39688871d2caafce42a09316381a27ae (diff)
Refactoring data with attrs and refactor conf for hydra
Diffstat (limited to 'text_recognizer/data/iam_extended_paragraphs.py')
-rw-r--r--text_recognizer/data/iam_extended_paragraphs.py33
1 files changed, 18 insertions, 15 deletions
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 0a30a42..886e37e 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -1,4 +1,7 @@
"""IAM original and sythetic dataset class."""
+from typing import Dict, List
+
+import attr
from torch.utils.data import ConcatDataset
from text_recognizer.data.base_dataset import BaseDataset
@@ -7,22 +10,26 @@ from text_recognizer.data.iam_paragraphs import IAMParagraphs
from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs
+@attr.s(auto_attribs=True)
class IAMExtendedParagraphs(BaseDataModule):
- def __init__(
- self,
- batch_size: int = 16,
- num_workers: int = 0,
- train_fraction: float = 0.8,
- augment: bool = True,
- word_pieces: bool = False,
- ) -> None:
- super().__init__(batch_size, num_workers)
+ train_fraction: float = attr.ib()
+ word_pieces: bool = attr.ib(default=False)
+
+ def __attrs_post_init__(self) -> None:
self.iam_paragraphs = IAMParagraphs(
- batch_size, num_workers, train_fraction, augment, word_pieces,
+ self.batch_size,
+ self.num_workers,
+ self.train_fraction,
+ self.augment,
+ self.word_pieces,
)
self.iam_synthetic_paragraphs = IAMSyntheticParagraphs(
- batch_size, num_workers, train_fraction, augment, word_pieces,
+ self.batch_size,
+ self.num_workers,
+ self.train_fraction,
+ self.augment,
+ self.word_pieces,
)
self.dims = self.iam_paragraphs.dims
@@ -30,10 +37,6 @@ class IAMExtendedParagraphs(BaseDataModule):
self.mapping = self.iam_paragraphs.mapping
self.inverse_mapping = self.iam_paragraphs.inverse_mapping
- self.data_train: BaseDataset = None
- self.data_val: BaseDataset = None
- self.data_test: BaseDataset = None
-
def prepare_data(self) -> None:
"""Prepares the paragraphs data."""
self.iam_paragraphs.prepare_data()