summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--text_recognizer/data/base_data_module.py2
-rw-r--r--text_recognizer/data/base_dataset.py7
-rw-r--r--text_recognizer/data/emnist.py16
-rw-r--r--text_recognizer/data/iam_lines.py27
-rw-r--r--text_recognizer/data/iam_paragraphs.py2
-rw-r--r--text_recognizer/data/iam_synthetic_paragraphs.py2
-rw-r--r--text_recognizer/data/mappings/emnist.py22
7 files changed, 43 insertions, 35 deletions
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index 6306cf8..77d15e5 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -6,7 +6,7 @@ from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from text_recognizer.data.base_dataset import BaseDataset
-from text_recognizer.data.mappings.base import AbstractMapping
+from text_recognizer.data.mappings import AbstractMapping
T = TypeVar("T")
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py
index 675683a..4ceb818 100644
--- a/text_recognizer/data/base_dataset.py
+++ b/text_recognizer/data/base_dataset.py
@@ -32,13 +32,6 @@ class BaseDataset(Dataset):
self.targets = targets
self.transform = transform
self.target_transform = target_transform
-
- def __attrs_pre_init__(self) -> None:
- """Pre init constructor."""
- super().__init__()
-
- def __attrs_post_init__(self) -> None:
- """Post init constructor."""
if len(self.data) != len(self.targets):
raise ValueError("Data and targets must be of equal length.")
self.transform = self._load_transform(self.transform)
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index 1b1381a..ea27984 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -102,22 +102,6 @@ class EMNIST(BaseDataModule):
return basic + data
-def emnist_mapping(
- extra_symbols: Optional[Set[str]] = None,
-) -> Tuple[List, Dict[str, int], List[int]]:
- """Return the EMNIST mapping."""
- if not ESSENTIALS_FILENAME.exists():
- download_and_process_emnist()
- with ESSENTIALS_FILENAME.open() as f:
- essentials = json.load(f)
- mapping = list(essentials["characters"])
- if extra_symbols is not None:
- mapping += extra_symbols
- inverse_mapping = {v: k for k, v in enumerate(mapping)}
- input_shape = essentials["input_shape"]
- return mapping, inverse_mapping, input_shape
-
-
def download_and_process_emnist() -> None:
"""Downloads and preprocesses EMNIST dataset."""
metadata = toml.load(METADATA_FILENAME)
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index 5f38f14..c23dec6 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -5,7 +5,7 @@ dataset.
"""
import json
from pathlib import Path
-from typing import List, Sequence, Tuple
+from typing import Callable, List, Optional, Sequence, Tuple, Type
from loguru import logger as log
import numpy as np
@@ -19,7 +19,7 @@ from text_recognizer.data.base_dataset import (
split_dataset,
)
from text_recognizer.data.iam import IAM
-from text_recognizer.data.mappings.emnist import EmnistMapping
+from text_recognizer.data.mappings import AbstractMapping, EmnistMapping
from text_recognizer.data.transforms.load_transform import load_transform_from_file
from text_recognizer.data.utils import image_utils
@@ -37,8 +37,27 @@ MAX_WORD_PIECE_LENGTH = 72
class IAMLines(BaseDataModule):
"""IAM handwritten lines dataset."""
- def __init__(self) -> None:
- super().__init__()
+ def __init__(
+ self,
+ mapping: Type[AbstractMapping],
+ transform: Optional[Callable] = None,
+ test_transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ train_fraction: float = 0.8,
+ batch_size: int = 16,
+ num_workers: int = 0,
+ pin_memory: bool = True,
+ ) -> None:
+ super().__init__(
+ mapping,
+ transform,
+ test_transform,
+ target_transform,
+ train_fraction,
+ batch_size,
+ num_workers,
+ pin_memory,
+ )
self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
self.output_dims = (MAX_LABEL_LENGTH, 1)
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index dde505d..9c75129 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -16,7 +16,7 @@ from text_recognizer.data.base_dataset import (
split_dataset,
)
from text_recognizer.data.iam import IAM
-from text_recognizer.data.mappings.emnist import EmnistMapping
+from text_recognizer.data.mappings import EmnistMapping
from text_recognizer.data.transforms.load_transform import load_transform_from_file
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index 2e7762c..1dc517d 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -24,7 +24,7 @@ from text_recognizer.data.iam_paragraphs import (
NEW_LINE_TOKEN,
resize_image,
)
-from text_recognizer.data.mappings.emnist import EmnistMapping
+from text_recognizer.data.mappings import EmnistMapping
from text_recognizer.data.transforms.load_transform import load_transform_from_file
diff --git a/text_recognizer/data/mappings/emnist.py b/text_recognizer/data/mappings/emnist.py
index 51e4677..ecd862e 100644
--- a/text_recognizer/data/mappings/emnist.py
+++ b/text_recognizer/data/mappings/emnist.py
@@ -1,12 +1,15 @@
"""Emnist mapping."""
-from typing import List, Optional, Sequence, Union
+import json
+from pathlib import Path
+from typing import Dict, List, Optional, Sequence, Union, Tuple
import torch
from torch import Tensor
-from text_recognizer.data.emnist import emnist_mapping
from text_recognizer.data.mappings.base import AbstractMapping
+ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json"
+
class EmnistMapping(AbstractMapping):
"""Mapping for EMNIST labels."""
@@ -15,13 +18,22 @@ class EmnistMapping(AbstractMapping):
self, extra_symbols: Optional[Sequence[str]] = None, lower: bool = True
) -> None:
self.extra_symbols = set(extra_symbols) if extra_symbols is not None else None
- self.mapping, self.inverse_mapping, self.input_size = emnist_mapping(
- self.extra_symbols
- )
+ self.mapping, self.inverse_mapping, self.input_size = self._load_mapping()
if lower:
self._to_lower()
super().__init__(self.input_size, self.mapping, self.inverse_mapping)
+ def _load_mapping(self) -> Tuple[List, Dict[str, int], List[int]]:
+ """Return the EMNIST mapping."""
+ with ESSENTIALS_FILENAME.open() as f:
+ essentials = json.load(f)
+ mapping = list(essentials["characters"])
+ if self.extra_symbols is not None:
+ mapping += self.extra_symbols
+ inverse_mapping = {v: k for k, v in enumerate(mapping)}
+ input_shape = essentials["input_shape"]
+ return mapping, inverse_mapping, input_shape
+
def _to_lower(self) -> None:
"""Converts mapping to lowercase letters only."""