summaryrefslogtreecommitdiff
path: root/text_recognizer/data/emnist_lines.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/emnist_lines.py')
-rw-r--r--text_recognizer/data/emnist_lines.py21
1 files changed, 10 insertions, 11 deletions
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 7548ad5..5298726 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -1,11 +1,11 @@
"""Dataset of generated text from EMNIST characters."""
from collections import defaultdict
from pathlib import Path
-from typing import Callable, Dict, Tuple
+from typing import Callable, List, Tuple
import attr
import h5py
-from loguru import logger
+from loguru import logger as log
import numpy as np
import torch
from torchvision import transforms
@@ -46,8 +46,7 @@ class EMNISTLines(BaseDataModule):
emnist: EMNIST = attr.ib(init=False, default=None)
def __attrs_post_init__(self) -> None:
- self.emnist = EMNIST()
- self.mapping = self.emnist.mapping
+ self.emnist = EMNIST(mapping=self.mapping)
max_width = (
int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap))
@@ -86,7 +85,7 @@ class EMNISTLines(BaseDataModule):
self._generate_data("test")
def setup(self, stage: str = None) -> None:
- logger.info("EMNISTLinesDataset loading data from HDF5...")
+ log.info("EMNISTLinesDataset loading data from HDF5...")
if stage == "fit" or stage is None:
print(self.data_filename)
with h5py.File(self.data_filename, "r") as f:
@@ -137,7 +136,7 @@ class EMNISTLines(BaseDataModule):
return basic + data
def _generate_data(self, split: str) -> None:
- logger.info(f"EMNISTLines generating data for {split}...")
+ log.info(f"EMNISTLines generating data for {split}...")
sentence_generator = SentenceGenerator(
self.max_length - 2
) # Subtract by 2 because start/end token
@@ -148,17 +147,17 @@ class EMNISTLines(BaseDataModule):
if split == "train":
samples_by_char = _get_samples_by_char(
- emnist.x_train, emnist.y_train, emnist.mapping
+ emnist.x_train, emnist.y_train, self.mapping.mapping
)
num = self.num_train
elif split == "val":
samples_by_char = _get_samples_by_char(
- emnist.x_train, emnist.y_train, emnist.mapping
+ emnist.x_train, emnist.y_train, self.mapping.mapping
)
num = self.num_val
else:
samples_by_char = _get_samples_by_char(
- emnist.x_test, emnist.y_test, emnist.mapping
+ emnist.x_test, emnist.y_test, self.mapping.mapping
)
num = self.num_test
@@ -173,14 +172,14 @@ class EMNISTLines(BaseDataModule):
self.dims,
)
y = convert_strings_to_labels(
- y, emnist.inverse_mapping, length=MAX_OUTPUT_LENGTH
+ y, self.mapping.inverse_mapping, length=MAX_OUTPUT_LENGTH
)
f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf")
f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf")
def _get_samples_by_char(
- samples: np.ndarray, labels: np.ndarray, mapping: Dict
+ samples: np.ndarray, labels: np.ndarray, mapping: List
) -> defaultdict:
samples_by_char = defaultdict(list)
for sample, label in zip(samples, labels):