From beeaef529e7c893a3475fe27edc880e283373725 Mon Sep 17 00:00:00 2001
From: aktersnurra <gustaf.rydholm@gmail.com>
Date: Sun, 8 Nov 2020 12:41:04 +0100
Subject: Trying to get the CNNTransformer to work, but it is hard.

---
 src/text_recognizer/datasets/emnist_dataset.py     |  5 ++++
 .../datasets/emnist_lines_dataset.py               | 35 ++++++++++++++++------
 src/text_recognizer/datasets/transforms.py         | 27 ++++++++++-------
 3 files changed, 48 insertions(+), 19 deletions(-)

(limited to 'src/text_recognizer/datasets')

diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index a8901d6..9884fdf 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -22,6 +22,7 @@ class EmnistDataset(Dataset):
 
     def __init__(
         self,
+        pad_token: str = None,
         train: bool = False,
         sample_to_balance: bool = False,
         subsample_fraction: float = None,
@@ -32,6 +33,7 @@ class EmnistDataset(Dataset):
         """Loads the dataset and the mappings.
 
         Args:
+            pad_token (str): The pad token symbol. Defaults to _.
             train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False.
             sample_to_balance (bool): Resamples the dataset to make it balanced. Defaults to False.
             subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None.
@@ -45,6 +47,7 @@ class EmnistDataset(Dataset):
             subsample_fraction=subsample_fraction,
             transform=transform,
             target_transform=target_transform,
+            pad_token=pad_token,
         )
 
         self.sample_to_balance = sample_to_balance
@@ -53,6 +56,8 @@ class EmnistDataset(Dataset):
         if transform is None:
             self.transform = Compose([Transpose(), ToTensor()])
 
+        self.target_transform = None
+
         self.seed = seed
 
     def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index 6091da8..6871492 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -4,6 +4,7 @@ from collections import defaultdict
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
+import click
 import h5py
 from loguru import logger
 import numpy as np
@@ -58,13 +59,15 @@ class EmnistLinesDataset(Dataset):
             eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
 
         """
+        self.pad_token = "_" if pad_token is None else pad_token
+
         super().__init__(
             train=train,
             transform=transform,
             target_transform=target_transform,
             subsample_fraction=subsample_fraction,
             init_token=init_token,
-            pad_token=pad_token,
+            pad_token=self.pad_token,
             eos_token=eos_token,
         )
 
@@ -127,11 +130,7 @@ class EmnistLinesDataset(Dataset):
     @property
     def data_filename(self) -> Path:
         """Path to the h5 file."""
-        filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt"
-        if self.train:
-            filename = "train_" + filename
-        else:
-            filename = "test_" + filename
+        filename = "train.pt" if self.train else "test.pt"
         return DATA_DIRNAME / filename
 
     def load_or_generate_data(self) -> None:
@@ -147,8 +146,8 @@ class EmnistLinesDataset(Dataset):
         """Loads the dataset from the h5 file."""
         logger.debug("EmnistLinesDataset loading data from HDF5...")
         with h5py.File(self.data_filename, "r") as f:
-            self._data = f["data"][:]
-            self._targets = f["targets"][:]
+            self._data = f["data"][()]
+            self._targets = f["targets"][()]
 
     def _generate_data(self) -> str:
         """Generates a dataset with the Brown corpus and Emnist characters."""
@@ -157,7 +156,9 @@ class EmnistLinesDataset(Dataset):
         sentence_generator = SentenceGenerator(self.max_length)
 
         # Load emnist dataset.
-        emnist = EmnistDataset(train=self.train, sample_to_balance=True)
+        emnist = EmnistDataset(
+            train=self.train, sample_to_balance=True, pad_token=self.pad_token
+        )
         emnist.load_or_generate_data()
 
         samples_by_character = get_samples_by_character(
@@ -308,6 +309,18 @@ def convert_strings_to_categorical_labels(
     return np.array([[mapping[c] for c in label] for label in labels])
 
 
+@click.command()
+@click.option(
+    "--max_length", type=int, default=34, help="Number of characters in a sentence."
+)
+@click.option(
+    "--min_overlap", type=float, default=0.0, help="Min overlap between characters."
+)
+@click.option(
+    "--max_overlap", type=float, default=0.33, help="Max overlap between characters."
+)
+@click.option("--num_train", type=int, default=10_000, help="Number of train examples.")
+@click.option("--num_test", type=int, default=1_000, help="Number of test examples.")
 def create_datasets(
     max_length: int = 34,
     min_overlap: float = 0,
@@ -326,3 +339,7 @@ def create_datasets(
             num_samples=num,
         )
         emnist_lines.load_or_generate_data()
+
+
+if __name__ == "__main__":
+    create_datasets()
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py
index c058972..8deac7f 100644
--- a/src/text_recognizer/datasets/transforms.py
+++ b/src/text_recognizer/datasets/transforms.py
@@ -3,7 +3,7 @@ import numpy as np
 from PIL import Image
 import torch
 from torch import Tensor
-from torchvision.transforms import Compose, ToTensor
+from torchvision.transforms import Compose, Resize, ToPILImage, ToTensor
 
 from text_recognizer.datasets.util import EmnistMapper
 
@@ -19,28 +19,35 @@ class Transpose:
 class AddTokens:
     """Adds start of sequence and end of sequence tokens to target tensor."""
 
-    def __init__(self, init_token: str, pad_token: str, eos_token: str,) -> None:
+    def __init__(self, pad_token: str, eos_token: str, init_token: str = None) -> None:
         self.init_token = init_token
         self.pad_token = pad_token
         self.eos_token = eos_token
-        self.emnist_mapper = EmnistMapper(
-            init_token=self.init_token,
-            pad_token=self.pad_token,
-            eos_token=self.eos_token,
-        )
+        if self.init_token is not None:
+            self.emnist_mapper = EmnistMapper(
+                init_token=self.init_token,
+                pad_token=self.pad_token,
+                eos_token=self.eos_token,
+            )
+        else:
+            self.emnist_mapper = EmnistMapper(
+                pad_token=self.pad_token, eos_token=self.eos_token,
+            )
         self.pad_value = self.emnist_mapper(self.pad_token)
-        self.sos_value = self.emnist_mapper(self.init_token)
         self.eos_value = self.emnist_mapper(self.eos_token)
 
     def __call__(self, target: Tensor) -> Tensor:
         """Adds a sos token to the begining and a eos token to the end of a target sequence."""
         dtype, device = target.dtype, target.device
-        sos = torch.tensor([self.sos_value], dtype=dtype, device=device)
 
         # Find the where padding starts.
         pad_index = torch.nonzero(target == self.pad_value, as_tuple=False)[0].item()
 
         target[pad_index] = self.eos_value
 
-        target = torch.cat([sos, target], dim=0)
+        if self.init_token is not None:
+            self.sos_value = self.emnist_mapper(self.init_token)
+            sos = torch.tensor([self.sos_value], dtype=dtype, device=device)
+            target = torch.cat([sos, target], dim=0)
+
         return target
-- 
cgit v1.2.3-70-g09d2