Diffstat (limited to 'text_recognizer/data')
-rw-r--r--  text_recognizer/data/base_dataset.py              |  1
-rw-r--r--  text_recognizer/data/emnist.py                     |  2
-rw-r--r--  text_recognizer/data/iam_extended_paragraphs.py    | 23
-rw-r--r--  text_recognizer/data/iam_lines.py                  |  6
-rw-r--r--  text_recognizer/data/iam_paragraphs.py             |  7
-rw-r--r--  text_recognizer/data/iam_synthetic_paragraphs.py   | 12
6 files changed, 22 insertions(+), 29 deletions(-)
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py
index 4318dfb..c26f1c9 100644
--- a/text_recognizer/data/base_dataset.py
+++ b/text_recognizer/data/base_dataset.py
@@ -29,6 +29,7 @@ class BaseDataset(Dataset):
super().__init__()
def __attrs_post_init__(self) -> None:
+ # TODO: refactor this
if len(self.data) != len(self.targets):
raise ValueError("Data and targets must be of equal length.")
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index d51a42a..2d0ac29 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -46,7 +46,7 @@ class EMNIST(BaseDataModule):
EMNIST ByClass: 814,255 characters. 62 unbalanced classes.
"""
- train_fraction: float = attr.ib()
+ train_fraction: float = attr.ib(default=0.8)
transform: Callable = attr.ib(init=False, default=T.Compose([T.ToTensor()]))
def __attrs_post_init__(self) -> None:
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 886e37e..58c7369 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -13,23 +13,24 @@ from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs
@attr.s(auto_attribs=True)
class IAMExtendedParagraphs(BaseDataModule):
- train_fraction: float = attr.ib()
+ augment: bool = attr.ib(default=True)
+ train_fraction: float = attr.ib(default=0.8)
word_pieces: bool = attr.ib(default=False)
def __attrs_post_init__(self) -> None:
self.iam_paragraphs = IAMParagraphs(
- self.batch_size,
- self.num_workers,
- self.train_fraction,
- self.augment,
- self.word_pieces,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ train_fraction=self.train_fraction,
+ augment=self.augment,
+ word_pieces=self.word_pieces,
)
self.iam_synthetic_paragraphs = IAMSyntheticParagraphs(
- self.batch_size,
- self.num_workers,
- self.train_fraction,
- self.augment,
- self.word_pieces,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ train_fraction=self.train_fraction,
+ augment=self.augment,
+ word_pieces=self.word_pieces,
)
self.dims = self.iam_paragraphs.dims
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index e45e5c8..705cfa3 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -34,6 +34,7 @@ SEED = 4711
PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_lines"
IMAGE_HEIGHT = 56
IMAGE_WIDTH = 1024
+MAX_LABEL_LENGTH = 89
@attr.s(auto_attribs=True)
@@ -42,11 +43,12 @@ class IAMLines(BaseDataModule):
augment: bool = attr.ib(default=True)
fraction: float = attr.ib(default=0.8)
+ dims: Tuple[int, int, int] = attr.ib(init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH))
+ output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1))
def __attrs_post_init__(self) -> None:
+ # TODO: refactor this
self.mapping, self.inverse_mapping, _ = emnist_mapping()
- self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
- self.output_dims = (89, 1)
def prepare_data(self) -> None:
"""Creates the IAM lines dataset if not existing."""
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index bdfb490..9977978 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -41,6 +41,8 @@ class IAMParagraphs(BaseDataModule):
augment: bool = attr.ib(default=True)
train_fraction: float = attr.ib(default=0.8)
word_pieces: bool = attr.ib(default=False)
+ dims: Tuple[int, int, int] = attr.ib(init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH))
+ output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1))
def __attrs_post_init__(self) -> None:
self.mapping, self.inverse_mapping, _ = emnist_mapping(
@@ -49,11 +51,6 @@ class IAMParagraphs(BaseDataModule):
if self.word_pieces:
self.mapping = WordPieceMapping()
- self.train_fraction = train_fraction
-
- self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
- self.output_dims = (MAX_LABEL_LENGTH, 1)
-
def prepare_data(self) -> None:
"""Create data for training/testing."""
if PROCESSED_DATA_DIRNAME.exists():
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index 00fa2b6..a3697e7 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -2,6 +2,7 @@
import random
from typing import Any, List, Sequence, Tuple
+import attr
from loguru import logger
import numpy as np
from PIL import Image
@@ -33,19 +34,10 @@ PROCESSED_DATA_DIRNAME = (
)
+@attr.s(auto_attribs=True)
class IAMSyntheticParagraphs(IAMParagraphs):
"""IAM Handwriting database of synthetic paragraphs."""
- def __init__(
- self,
- batch_size: int = 16,
- num_workers: int = 0,
- train_fraction: float = 0.8,
- augment: bool = True,
- word_pieces: bool = False,
- ) -> None:
- super().__init__(batch_size, num_workers, train_fraction, augment, word_pieces)
-
def prepare_data(self) -> None:
"""Prepare IAM lines to be used to generate paragraphs."""
if PROCESSED_DATA_DIRNAME.exists():
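
Below is a minimal sketch, not part of the commit, illustrating the attrs pattern this diff moves the data modules toward: auto_attribs fields with attr.ib defaults, init=False derived fields, validation in __attrs_post_init__, and keyword-argument construction. The class name ExampleDataModule and its field values are illustrative assumptions, not the actual BaseDataModule API.

# Illustrative sketch only; names and values below are hypothetical.
import attr


@attr.s(auto_attribs=True)
class ExampleDataModule:
    batch_size: int = attr.ib(default=16)
    num_workers: int = attr.ib(default=0)
    train_fraction: float = attr.ib(default=0.8)
    # Derived field: excluded from __init__ and set from a fixed default,
    # mirroring the dims/output_dims attr.ib(init=False, ...) fields above.
    dims: tuple = attr.ib(init=False, default=(1, 56, 1024))

    def __attrs_post_init__(self) -> None:
        # attrs calls this after all fields are assigned, so it is a natural
        # place for validation, as in BaseDataset.__attrs_post_init__.
        if not 0.0 < self.train_fraction < 1.0:
            raise ValueError("train_fraction must be in (0, 1).")


# Fields are passed by keyword, matching the explicit keyword arguments
# introduced in IAMExtendedParagraphs.__attrs_post_init__ above.
dm = ExampleDataModule(batch_size=32, train_fraction=0.9)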