diff options
Diffstat (limited to 'text_recognizer/data')
-rw-r--r-- | text_recognizer/data/base_data_module.py | 4 | ||||
-rw-r--r-- | text_recognizer/data/emnist_lines.py | 5 | ||||
-rw-r--r-- | text_recognizer/data/iam_extended_paragraphs.py | 4 | ||||
-rw-r--r-- | text_recognizer/data/iam_lines.py | 4 | ||||
-rw-r--r-- | text_recognizer/data/mappings/__init__.py | 1 | ||||
-rw-r--r-- | text_recognizer/data/mappings/base.py | 37 | ||||
-rw-r--r-- | text_recognizer/data/mappings/emnist.py | 15 |
7 files changed, 18 insertions, 52 deletions
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py index 77d15e5..28ba775 100644 --- a/text_recognizer/data/base_data_module.py +++ b/text_recognizer/data/base_data_module.py @@ -6,7 +6,7 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader from text_recognizer.data.base_dataset import BaseDataset -from text_recognizer.data.mappings import AbstractMapping +from text_recognizer.data.mappings import EmnistMapping T = TypeVar("T") @@ -24,7 +24,7 @@ class BaseDataModule(LightningDataModule): def __init__( self, - mapping: Type[AbstractMapping], + mapping: Type[EmnistMapping], transform: Optional[Callable] = None, test_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py index 062257d..ba1b61c 100644 --- a/text_recognizer/data/emnist_lines.py +++ b/text_recognizer/data/emnist_lines.py @@ -16,7 +16,7 @@ from text_recognizer.data.base_data_module import ( ) from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels from text_recognizer.data.emnist import EMNIST -from text_recognizer.data.mappings import AbstractMapping +from text_recognizer.data.mappings import EmnistMapping from text_recognizer.data.transforms.load_transform import load_transform_from_file from text_recognizer.data.utils.sentence_generator import SentenceGenerator @@ -38,7 +38,7 @@ class EMNISTLines(BaseDataModule): def __init__( self, - mapping: Type[AbstractMapping], + mapping: Type[EmnistMapping], transform: Optional[Callable] = None, test_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, @@ -113,7 +113,6 @@ class EMNISTLines(BaseDataModule): """Loads the dataset.""" log.info("EMNISTLinesDataset loading data from HDF5...") if stage == "fit" or stage is None: - print(self.data_filename) with h5py.File(self.data_filename, "r") as f: x_train = f["x_train"][:] y_train = torch.LongTensor(f["y_train"][:]) diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py index 0c16181..61bf6a3 100644 --- a/text_recognizer/data/iam_extended_paragraphs.py +++ b/text_recognizer/data/iam_extended_paragraphs.py @@ -4,7 +4,7 @@ from torch.utils.data import ConcatDataset from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.iam_paragraphs import IAMParagraphs -from text_recognizer.data.mappings import AbstractMapping +from text_recognizer.data.mappings import EmnistMapping from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs from text_recognizer.data.transforms.load_transform import load_transform_from_file @@ -14,7 +14,7 @@ class IAMExtendedParagraphs(BaseDataModule): def __init__( self, - mapping: Type[AbstractMapping], + mapping: Type[EmnistMapping], transform: Optional[Callable] = None, test_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py index c23dec6..cf50b60 100644 --- a/text_recognizer/data/iam_lines.py +++ b/text_recognizer/data/iam_lines.py @@ -19,7 +19,7 @@ from text_recognizer.data.base_dataset import ( split_dataset, ) from text_recognizer.data.iam import IAM -from text_recognizer.data.mappings import AbstractMapping, EmnistMapping +from text_recognizer.data.mappings import EmnistMapping from text_recognizer.data.transforms.load_transform import load_transform_from_file from text_recognizer.data.utils import image_utils @@ -39,7 +39,7 @@ class IAMLines(BaseDataModule): def __init__( self, - mapping: Type[AbstractMapping], + mapping: Type[EmnistMapping], transform: Optional[Callable] = None, test_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, diff --git a/text_recognizer/data/mappings/__init__.py b/text_recognizer/data/mappings/__init__.py index 84177cb..635f506 100644 --- a/text_recognizer/data/mappings/__init__.py +++ b/text_recognizer/data/mappings/__init__.py @@ -1,3 +1,2 @@ """Mapping modules.""" -from text_recognizer.data.mappings.base import AbstractMapping from text_recognizer.data.mappings.emnist import EmnistMapping diff --git a/text_recognizer/data/mappings/base.py b/text_recognizer/data/mappings/base.py deleted file mode 100644 index 572ac95..0000000 --- a/text_recognizer/data/mappings/base.py +++ /dev/null @@ -1,37 +0,0 @@ -"""Mapping to and from word pieces.""" -from abc import ABC, abstractmethod -from typing import Dict, List - -from torch import Tensor - - -class AbstractMapping(ABC): - def __init__( - self, input_size: List[int], mapping: List[str], inverse_mapping: Dict[str, int] - ) -> None: - self.input_size = input_size - self.mapping = mapping - self.inverse_mapping = inverse_mapping - - def __len__(self) -> int: - return len(self.mapping) - - @property - def num_classes(self) -> int: - return self.__len__() - - @abstractmethod - def get_token(self, *args, **kwargs) -> str: - ... - - @abstractmethod - def get_index(self, *args, **kwargs) -> Tensor: - ... - - @abstractmethod - def get_text(self, *args, **kwargs) -> str: - ... - - @abstractmethod - def get_indices(self, *args, **kwargs) -> Tensor: - ... diff --git a/text_recognizer/data/mappings/emnist.py b/text_recognizer/data/mappings/emnist.py index ecd862e..606d200 100644 --- a/text_recognizer/data/mappings/emnist.py +++ b/text_recognizer/data/mappings/emnist.py @@ -6,22 +6,27 @@ from typing import Dict, List, Optional, Sequence, Union, Tuple import torch from torch import Tensor -from text_recognizer.data.mappings.base import AbstractMapping - ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json" -class EmnistMapping(AbstractMapping): +class EmnistMapping: """Mapping for EMNIST labels.""" def __init__( - self, extra_symbols: Optional[Sequence[str]] = None, lower: bool = True + self, + input_size: List[int], + mapping: List[str], + inverse_mapping: Dict[str, int], + extra_symbols: Optional[Sequence[str]] = None, + lower: bool = True, ) -> None: + self.input_size = input_size + self.mapping = mapping + self.inverse_mapping = inverse_mapping self.extra_symbols = set(extra_symbols) if extra_symbols is not None else None self.mapping, self.inverse_mapping, self.input_size = self._load_mapping() if lower: self._to_lower() - super().__init__(self.input_size, self.mapping, self.inverse_mapping) def _load_mapping(self) -> Tuple[List, Dict[str, int], List[int]]: """Return the EMNIST mapping.""" |