From 6df941bdf5cad80db38d851dcb23a08a9dc55617 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Sat, 11 Jun 2022 23:09:22 +0200
Subject: Fix emnist mapping

---
 text_recognizer/data/base_data_module.py         |  4 ++--
 text_recognizer/data/emnist_lines.py             |  2 +-
 text_recognizer/data/iam_extended_paragraphs.py  |  4 ++--
 text_recognizer/data/iam_lines.py                |  2 +-
 text_recognizer/data/iam_paragraphs.py           | 25 +++++++++++++++++++++---
 text_recognizer/data/iam_synthetic_paragraphs.py | 25 +++++++++++++++++++++---
 text_recognizer/data/mappings/emnist.py          | 13 ++++++------
 text_recognizer/models/base.py                   |  4 ++--
 text_recognizer/models/conformer.py              |  2 +-
 text_recognizer/models/transformer.py            |  2 +-
 10 files changed, 61 insertions(+), 22 deletions(-)

diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index 28ba775..69581a3 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -1,6 +1,6 @@
 """Base lightning DataModule class."""
 from pathlib import Path
-from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
+from typing import Callable, Dict, Optional, Tuple, TypeVar
 
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader
@@ -24,7 +24,7 @@ class BaseDataModule(LightningDataModule):
 
     def __init__(
         self,
-        mapping: Type[EmnistMapping],
+        mapping: EmnistMapping,
         transform: Optional[Callable] = None,
         test_transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index ba1b61c..2107d74 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -38,7 +38,7 @@ class EMNISTLines(BaseDataModule):
 
     def __init__(
         self,
-        mapping: Type[EmnistMapping],
+        mapping: EmnistMapping,
         transform: Optional[Callable] = None,
         test_transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 61bf6a3..90df5f8 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -1,5 +1,5 @@
 """IAM original and sythetic dataset class."""
-from typing import Callable, Optional, Type
+from typing import Callable, Optional
 from torch.utils.data import ConcatDataset
 
 from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
@@ -14,7 +14,7 @@ class IAMExtendedParagraphs(BaseDataModule):
 
     def __init__(
         self,
-        mapping: Type[EmnistMapping],
+        mapping: EmnistMapping,
         transform: Optional[Callable] = None,
         test_transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index cf50b60..4899a48 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -39,7 +39,7 @@ class IAMLines(BaseDataModule):
 
     def __init__(
         self,
-        mapping: Type[EmnistMapping],
+        mapping: EmnistMapping,
         transform: Optional[Callable] = None,
         test_transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 9c75129..3bf28ff 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -1,7 +1,7 @@
 """IAM Paragraphs Dataset class."""
 import json
 from pathlib import Path
-from typing import Dict, List, Optional, Sequence, Tuple
+from typing import Callable, Dict, List, Optional, Sequence, Tuple
 
 from loguru import logger as log
 import numpy as np
@@ -35,8 +35,27 @@ MAX_WORD_PIECE_LENGTH = 451
 class IAMParagraphs(BaseDataModule):
     """IAM handwriting database paragraphs."""
 
-    def __init__(self) -> None:
-        super().__init__()
+    def __init__(
+        self,
+        mapping: EmnistMapping,
+        transform: Optional[Callable] = None,
+        test_transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        train_fraction: float = 0.8,
+        batch_size: int = 16,
+        num_workers: int = 0,
+        pin_memory: bool = True,
+    ) -> None:
+        super().__init__(
+            mapping,
+            transform,
+            test_transform,
+            target_transform,
+            train_fraction,
+            batch_size,
+            num_workers,
+            pin_memory,
+        )
         self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
         self.output_dims = (MAX_LABEL_LENGTH, 1)
 
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index 1dc517d..d51e010 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -1,6 +1,6 @@
 """IAM Synthetic Paragraphs Dataset class."""
 import random
-from typing import Any, List, Sequence, Tuple
+from typing import Any, Callable, List, Optional, Sequence, Tuple
 
 from loguru import logger as log
 import numpy as np
@@ -36,8 +36,27 @@ PROCESSED_DATA_DIRNAME = (
 class IAMSyntheticParagraphs(IAMParagraphs):
     """IAM Handwriting database of synthetic paragraphs."""
 
-    def __init__(self) -> None:
-        super().__init__()
+    def __init__(
+        self,
+        mapping: EmnistMapping,
+        transform: Optional[Callable] = None,
+        test_transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        train_fraction: float = 0.8,
+        batch_size: int = 16,
+        num_workers: int = 0,
+        pin_memory: bool = True,
+    ) -> None:
+        super().__init__(
+            mapping,
+            transform,
+            test_transform,
+            target_transform,
+            train_fraction,
+            batch_size,
+            num_workers,
+            pin_memory,
+        )
 
     def prepare_data(self) -> None:
         """Prepare IAM lines to be used to generate paragraphs."""
diff --git a/text_recognizer/data/mappings/emnist.py b/text_recognizer/data/mappings/emnist.py
index 606d200..03465a1 100644
--- a/text_recognizer/data/mappings/emnist.py
+++ b/text_recognizer/data/mappings/emnist.py
@@ -14,20 +14,21 @@ class EmnistMapping:
 
     def __init__(
         self,
-        input_size: List[int],
-        mapping: List[str],
-        inverse_mapping: Dict[str, int],
         extra_symbols: Optional[Sequence[str]] = None,
         lower: bool = True,
     ) -> None:
-        self.input_size = input_size
-        self.mapping = mapping
-        self.inverse_mapping = inverse_mapping
         self.extra_symbols = set(extra_symbols) if extra_symbols is not None else None
         self.mapping, self.inverse_mapping, self.input_size = self._load_mapping()
         if lower:
             self._to_lower()
 
+    def __len__(self) -> int:
+        return len(self.mapping)
+
+    @property
+    def num_classes(self) -> int:
+        return self.__len__()
+
     def _load_mapping(self) -> Tuple[List, Dict[str, int], List[int]]:
         """Return the EMNIST mapping."""
         with ESSENTIALS_FILENAME.open() as f:
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 886394d..26cc18c 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -10,7 +10,7 @@ from torch import nn
 from torch import Tensor
 from torchmetrics import Accuracy
 
-from text_recognizer.data.mappings.base import EmnistMapping
+from text_recognizer.data.mappings import EmnistMapping
 
 
 class LitBase(LightningModule):
@@ -22,7 +22,7 @@ class LitBase(LightningModule):
         loss_fn: Type[nn.Module],
         optimizer_configs: DictConfig,
         lr_scheduler_configs: Optional[DictConfig],
-        mapping: Type[EmnistMapping],
+        mapping: EmnistMapping,
     ) -> None:
         super().__init__()
 
diff --git a/text_recognizer/models/conformer.py b/text_recognizer/models/conformer.py
index 655ebf6..41a9e4d 100644
--- a/text_recognizer/models/conformer.py
+++ b/text_recognizer/models/conformer.py
@@ -21,7 +21,7 @@ class LitConformer(LitBase):
         loss_fn: Type[nn.Module],
         optimizer_configs: DictConfig,
         lr_scheduler_configs: Optional[DictConfig],
-        mapping: Type[EmnistMapping],
+        mapping: EmnistMapping,
         max_output_len: int = 451,
         start_token: str = "<s>",
         end_token: str = "<e>",
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 1ffff60..b511947 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -19,7 +19,7 @@ class LitTransformer(LitBase):
         loss_fn: Type[nn.Module],
         optimizer_configs: DictConfig,
         lr_scheduler_configs: Optional[DictConfig],
-        mapping: Type[EmnistMapping],
+        mapping: EmnistMapping,
         max_output_len: int = 451,
         start_token: str = "<s>",
         end_token: str = "<e>",
-- 
cgit v1.2.3-70-g09d2