diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-09 22:30:45 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-09 22:30:45 +0200 |
commit | 316c6a456c9b9f1964f77e8ba016651405c6f9c0 (patch) | |
tree | d877b1eb429820ccf5bb0a0426358910597d203a /text_recognizer | |
parent | a96fa058827b739238972569f7c559c75ba6514f (diff) |
Remove abstract mapping
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/data/base_data_module.py | 4 | ||||
-rw-r--r-- | text_recognizer/data/emnist_lines.py | 5 | ||||
-rw-r--r-- | text_recognizer/data/iam_extended_paragraphs.py | 4 | ||||
-rw-r--r-- | text_recognizer/data/iam_lines.py | 4 | ||||
-rw-r--r-- | text_recognizer/data/mappings/__init__.py | 1 | ||||
-rw-r--r-- | text_recognizer/data/mappings/base.py | 37 | ||||
-rw-r--r-- | text_recognizer/data/mappings/emnist.py | 15 | ||||
-rw-r--r-- | text_recognizer/models/base.py | 4 | ||||
-rw-r--r-- | text_recognizer/models/conformer.py | 4 | ||||
-rw-r--r-- | text_recognizer/models/transformer.py | 4 |
10 files changed, 24 insertions, 58 deletions
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py index 77d15e5..28ba775 100644 --- a/text_recognizer/data/base_data_module.py +++ b/text_recognizer/data/base_data_module.py @@ -6,7 +6,7 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader from text_recognizer.data.base_dataset import BaseDataset -from text_recognizer.data.mappings import AbstractMapping +from text_recognizer.data.mappings import EmnistMapping T = TypeVar("T") @@ -24,7 +24,7 @@ class BaseDataModule(LightningDataModule): def __init__( self, - mapping: Type[AbstractMapping], + mapping: Type[EmnistMapping], transform: Optional[Callable] = None, test_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py index 062257d..ba1b61c 100644 --- a/text_recognizer/data/emnist_lines.py +++ b/text_recognizer/data/emnist_lines.py @@ -16,7 +16,7 @@ from text_recognizer.data.base_data_module import ( ) from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels from text_recognizer.data.emnist import EMNIST -from text_recognizer.data.mappings import AbstractMapping +from text_recognizer.data.mappings import EmnistMapping from text_recognizer.data.transforms.load_transform import load_transform_from_file from text_recognizer.data.utils.sentence_generator import SentenceGenerator @@ -38,7 +38,7 @@ class EMNISTLines(BaseDataModule): def __init__( self, - mapping: Type[AbstractMapping], + mapping: Type[EmnistMapping], transform: Optional[Callable] = None, test_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, @@ -113,7 +113,6 @@ class EMNISTLines(BaseDataModule): """Loads the dataset.""" log.info("EMNISTLinesDataset loading data from HDF5...") if stage == "fit" or stage is None: - print(self.data_filename) with h5py.File(self.data_filename, "r") as f: x_train = f["x_train"][:] y_train = torch.LongTensor(f["y_train"][:]) diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py index 0c16181..61bf6a3 100644 --- a/text_recognizer/data/iam_extended_paragraphs.py +++ b/text_recognizer/data/iam_extended_paragraphs.py @@ -4,7 +4,7 @@ from torch.utils.data import ConcatDataset from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.iam_paragraphs import IAMParagraphs -from text_recognizer.data.mappings import AbstractMapping +from text_recognizer.data.mappings import EmnistMapping from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs from text_recognizer.data.transforms.load_transform import load_transform_from_file @@ -14,7 +14,7 @@ class IAMExtendedParagraphs(BaseDataModule): def __init__( self, - mapping: Type[AbstractMapping], + mapping: Type[EmnistMapping], transform: Optional[Callable] = None, test_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py index c23dec6..cf50b60 100644 --- a/text_recognizer/data/iam_lines.py +++ b/text_recognizer/data/iam_lines.py @@ -19,7 +19,7 @@ from text_recognizer.data.base_dataset import ( split_dataset, ) from text_recognizer.data.iam import IAM -from text_recognizer.data.mappings import AbstractMapping, EmnistMapping +from text_recognizer.data.mappings import EmnistMapping from text_recognizer.data.transforms.load_transform import load_transform_from_file from text_recognizer.data.utils import image_utils @@ -39,7 +39,7 @@ class IAMLines(BaseDataModule): def __init__( self, - mapping: Type[AbstractMapping], + mapping: Type[EmnistMapping], transform: Optional[Callable] = None, test_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, diff --git a/text_recognizer/data/mappings/__init__.py b/text_recognizer/data/mappings/__init__.py index 84177cb..635f506 100644 --- a/text_recognizer/data/mappings/__init__.py +++ b/text_recognizer/data/mappings/__init__.py @@ -1,3 +1,2 @@ """Mapping modules.""" -from text_recognizer.data.mappings.base import AbstractMapping from text_recognizer.data.mappings.emnist import EmnistMapping diff --git a/text_recognizer/data/mappings/base.py b/text_recognizer/data/mappings/base.py deleted file mode 100644 index 572ac95..0000000 --- a/text_recognizer/data/mappings/base.py +++ /dev/null @@ -1,37 +0,0 @@ -"""Mapping to and from word pieces.""" -from abc import ABC, abstractmethod -from typing import Dict, List - -from torch import Tensor - - -class AbstractMapping(ABC): - def __init__( - self, input_size: List[int], mapping: List[str], inverse_mapping: Dict[str, int] - ) -> None: - self.input_size = input_size - self.mapping = mapping - self.inverse_mapping = inverse_mapping - - def __len__(self) -> int: - return len(self.mapping) - - @property - def num_classes(self) -> int: - return self.__len__() - - @abstractmethod - def get_token(self, *args, **kwargs) -> str: - ... - - @abstractmethod - def get_index(self, *args, **kwargs) -> Tensor: - ... - - @abstractmethod - def get_text(self, *args, **kwargs) -> str: - ... - - @abstractmethod - def get_indices(self, *args, **kwargs) -> Tensor: - ... diff --git a/text_recognizer/data/mappings/emnist.py b/text_recognizer/data/mappings/emnist.py index ecd862e..606d200 100644 --- a/text_recognizer/data/mappings/emnist.py +++ b/text_recognizer/data/mappings/emnist.py @@ -6,22 +6,27 @@ from typing import Dict, List, Optional, Sequence, Union, Tuple import torch from torch import Tensor -from text_recognizer.data.mappings.base import AbstractMapping - ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json" -class EmnistMapping(AbstractMapping): +class EmnistMapping: """Mapping for EMNIST labels.""" def __init__( - self, extra_symbols: Optional[Sequence[str]] = None, lower: bool = True + self, + input_size: List[int], + mapping: List[str], + inverse_mapping: Dict[str, int], + extra_symbols: Optional[Sequence[str]] = None, + lower: bool = True, ) -> None: + self.input_size = input_size + self.mapping = mapping + self.inverse_mapping = inverse_mapping self.extra_symbols = set(extra_symbols) if extra_symbols is not None else None self.mapping, self.inverse_mapping, self.input_size = self._load_mapping() if lower: self._to_lower() - super().__init__(self.input_size, self.mapping, self.inverse_mapping) def _load_mapping(self) -> Tuple[List, Dict[str, int], List[int]]: """Return the EMNIST mapping.""" diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 77c5509..886394d 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -10,7 +10,7 @@ from torch import nn from torch import Tensor from torchmetrics import Accuracy -from text_recognizer.data.mappings.base import AbstractMapping +from text_recognizer.data.mappings.base import EmnistMapping class LitBase(LightningModule): @@ -22,7 +22,7 @@ class LitBase(LightningModule): loss_fn: Type[nn.Module], optimizer_configs: DictConfig, lr_scheduler_configs: Optional[DictConfig], - mapping: Type[AbstractMapping], + mapping: Type[EmnistMapping], ) -> None: super().__init__() diff --git a/text_recognizer/models/conformer.py b/text_recognizer/models/conformer.py index 487eabe..655ebf6 100644 --- a/text_recognizer/models/conformer.py +++ b/text_recognizer/models/conformer.py @@ -6,7 +6,7 @@ from omegaconf import DictConfig import torch from torch import nn, Tensor -from text_recognizer.data.mappings import AbstractMapping +from text_recognizer.data.mappings import EmnistMapping from text_recognizer.models.base import LitBase from text_recognizer.models.metrics import CharacterErrorRate from text_recognizer.models.util import first_element @@ -21,7 +21,7 @@ class LitConformer(LitBase): loss_fn: Type[nn.Module], optimizer_configs: DictConfig, lr_scheduler_configs: Optional[DictConfig], - mapping: Type[AbstractMapping], + mapping: Type[EmnistMapping], max_output_len: int = 451, start_token: str = "<s>", end_token: str = "<e>", diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 4bbc671..1ffff60 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -5,7 +5,7 @@ from omegaconf import DictConfig import torch from torch import nn, Tensor -from text_recognizer.data.mappings import AbstractMapping +from text_recognizer.data.mappings import EmnistMapping from text_recognizer.models.base import LitBase from text_recognizer.models.metrics import CharacterErrorRate @@ -19,7 +19,7 @@ class LitTransformer(LitBase): loss_fn: Type[nn.Module], optimizer_configs: DictConfig, lr_scheduler_configs: Optional[DictConfig], - mapping: Type[AbstractMapping], + mapping: Type[EmnistMapping], max_output_len: int = 451, start_token: str = "<s>", end_token: str = "<e>", |