From 75801019981492eedf9280cb352eea3d8e99b65f Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 2 Aug 2021 21:13:48 +0200 Subject: Fix log import, fix mapping in datamodules, fix nn modules can be hashed --- text_recognizer/data/emnist_lines.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) (limited to 'text_recognizer/data/emnist_lines.py') diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py index 7548ad5..5298726 100644 --- a/text_recognizer/data/emnist_lines.py +++ b/text_recognizer/data/emnist_lines.py @@ -1,11 +1,11 @@ """Dataset of generated text from EMNIST characters.""" from collections import defaultdict from pathlib import Path -from typing import Callable, Dict, Tuple +from typing import Callable, List, Tuple import attr import h5py -from loguru import logger +from loguru import logger as log import numpy as np import torch from torchvision import transforms @@ -46,8 +46,7 @@ class EMNISTLines(BaseDataModule): emnist: EMNIST = attr.ib(init=False, default=None) def __attrs_post_init__(self) -> None: - self.emnist = EMNIST() - self.mapping = self.emnist.mapping + self.emnist = EMNIST(mapping=self.mapping) max_width = ( int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap)) @@ -86,7 +85,7 @@ class EMNISTLines(BaseDataModule): self._generate_data("test") def setup(self, stage: str = None) -> None: - logger.info("EMNISTLinesDataset loading data from HDF5...") + log.info("EMNISTLinesDataset loading data from HDF5...") if stage == "fit" or stage is None: print(self.data_filename) with h5py.File(self.data_filename, "r") as f: @@ -137,7 +136,7 @@ class EMNISTLines(BaseDataModule): return basic + data def _generate_data(self, split: str) -> None: - logger.info(f"EMNISTLines generating data for {split}...") + log.info(f"EMNISTLines generating data for {split}...") sentence_generator = SentenceGenerator( self.max_length - 2 ) # Subtract by 2 because start/end token @@ -148,17 +147,17 @@ class EMNISTLines(BaseDataModule): if split == "train": samples_by_char = _get_samples_by_char( - emnist.x_train, emnist.y_train, emnist.mapping + emnist.x_train, emnist.y_train, self.mapping.mapping ) num = self.num_train elif split == "val": samples_by_char = _get_samples_by_char( - emnist.x_train, emnist.y_train, emnist.mapping + emnist.x_train, emnist.y_train, self.mapping.mapping ) num = self.num_val else: samples_by_char = _get_samples_by_char( - emnist.x_test, emnist.y_test, emnist.mapping + emnist.x_test, emnist.y_test, self.mapping.mapping ) num = self.num_test @@ -173,14 +172,14 @@ class EMNISTLines(BaseDataModule): self.dims, ) y = convert_strings_to_labels( - y, emnist.inverse_mapping, length=MAX_OUTPUT_LENGTH + y, self.mapping.inverse_mapping, length=MAX_OUTPUT_LENGTH ) f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf") f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf") def _get_samples_by_char( - samples: np.ndarray, labels: np.ndarray, mapping: Dict + samples: np.ndarray, labels: np.ndarray, mapping: List ) -> defaultdict: samples_by_char = defaultdict(list) for sample, label in zip(samples, labels): -- cgit v1.2.3-70-g09d2