summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-11-15 17:40:44 +0100
committeraktersnurra <gustaf.rydholm@gmail.com>2020-11-15 17:40:44 +0100
commit75909723fa2b1f6245d5c5422e4f2e88b8a26052 (patch)
treee60c37d05c724db011d75adf9313d93839d193ac
parentcad676fc423efeafde65f03e4815248f2d357011 (diff)
Able to generate support files for lines datasets.
-rw-r--r--src/text_recognizer/datasets/dataset.py39
-rw-r--r--src/text_recognizer/datasets/iam_lines_dataset.py2
-rw-r--r--src/text_recognizer/datasets/util.py19
-rw-r--r--src/text_recognizer/models/base.py51
-rw-r--r--src/text_recognizer/tests/support/create_emnist_lines_support_files.py30
-rw-r--r--src/text_recognizer/tests/support/create_iam_lines_support_files.py27
-rw-r--r--src/text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.pngbin0 -> 2301 bytes
-rw-r--r--src/text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.pngbin0 -> 5424 bytes
-rw-r--r--src/text_recognizer/tests/support/emnist_lines/they<eos>.pngbin0 -> 1391 bytes
-rw-r--r--src/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.pngbin0 -> 5170 bytes
-rw-r--r--src/text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.pngbin0 -> 3617 bytes
-rw-r--r--src/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.pngbin0 -> 3923 bytes
-rw-r--r--src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt3
-rw-r--r--src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt3
-rw-r--r--src/training/run_experiment.py6
15 files changed, 96 insertions, 84 deletions
diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py
index 2de7f09..95063bc 100644
--- a/src/text_recognizer/datasets/dataset.py
+++ b/src/text_recognizer/datasets/dataset.py
@@ -1,11 +1,12 @@
"""Abstract dataset class."""
-from typing import Callable, Dict, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch.utils import data
from torchvision.transforms import ToTensor
+import text_recognizer.datasets.transforms as transforms
from text_recognizer.datasets.util import EmnistMapper
@@ -16,8 +17,8 @@ class Dataset(data.Dataset):
self,
train: bool,
subsample_fraction: float = None,
- transform: Optional[Callable] = None,
- target_transform: Optional[Callable] = None,
+ transform: Optional[List[Dict]] = None,
+ target_transform: Optional[List[Dict]] = None,
init_token: Optional[str] = None,
pad_token: Optional[str] = None,
eos_token: Optional[str] = None,
@@ -27,8 +28,8 @@ class Dataset(data.Dataset):
Args:
train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False.
subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None.
- transform (Optional[Callable]): Transform(s) for input data. Defaults to None.
- target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None.
+ transform (Optional[List[Dict]]): List of Transform types and args for input data. Defaults to None.
+ target_transform (Optional[List[Dict]]): List of Transform types and args for output data. Defaults to None.
init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
pad_token (Optional[str]): String representing the pad token. Defaults to None.
eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
@@ -53,14 +54,34 @@ class Dataset(data.Dataset):
self.num_classes = self.mapper.num_classes
# Set transforms.
- self.transform = transform if transform is not None else ToTensor()
- self.target_transform = (
- target_transform if target_transform is not None else torch.tensor
- )
+ self.transform = self._configure_transform(transform)
+ self.target_transform = self._configure_target_transform(target_transform)
self._data = None
self._targets = None
+ def _configure_transform(self, transform: List[Dict]) -> transforms.Compose:
+ transform_list = []
+ if transform is not None:
+ for t in transform:
+ t_type = t["type"]
+ t_args = t["args"] or {}
+ transform_list.append(getattr(transforms, t_type)(**t_args))
+ else:
+ transform_list.append(ToTensor())
+ return transforms.Compose(transform_list)
+
+ def _configure_target_transform(
+ self, target_transform: List[Dict]
+ ) -> transforms.Compose:
+ target_transform_list = [torch.tensor]
+ if target_transform is not None:
+ for t in target_transform:
+ t_type = t["type"]
+ t_args = t["args"] or {}
+ target_transform_list.append(getattr(transforms, t_type)(**t_args))
+ return transforms.Compose(target_transform_list)
+
@property
def data(self) -> Tensor:
"""The input data."""
diff --git a/src/text_recognizer/datasets/iam_lines_dataset.py b/src/text_recognizer/datasets/iam_lines_dataset.py
index fdd2fe6..5ae142c 100644
--- a/src/text_recognizer/datasets/iam_lines_dataset.py
+++ b/src/text_recognizer/datasets/iam_lines_dataset.py
@@ -36,6 +36,8 @@ class IamLinesDataset(Dataset):
pad_token: Optional[str] = None,
eos_token: Optional[str] = None,
) -> None:
+ self.pad_token = "_" if pad_token is None else pad_token
+
super().__init__(
train=train,
subsample_fraction=subsample_fraction,
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
index d2df8b5..bf5e772 100644
--- a/src/text_recognizer/datasets/util.py
+++ b/src/text_recognizer/datasets/util.py
@@ -12,7 +12,8 @@ import cv2
from loguru import logger
import numpy as np
from PIL import Image
-from torch.utils.data import DataLoader, Dataset
+import torch
+from torch import Tensor
from torchvision.datasets import EMNIST
from tqdm import tqdm
@@ -20,7 +21,7 @@ DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data"
ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json"
-def save_emnist_essentials(emnsit_dataset: type = EMNIST) -> None:
+def save_emnist_essentials(emnsit_dataset: EMNIST = EMNIST) -> None:
"""Extract and saves EMNIST essentials."""
labels = emnsit_dataset.classes
labels.sort()
@@ -56,21 +57,21 @@ class EmnistMapper:
self.eos_token = eos_token
self.essentials = self._load_emnist_essentials()
- # Load dataset infromation.
+ # Load dataset information.
self._mapping = dict(self.essentials["mapping"])
self._augment_emnist_mapping()
self._inverse_mapping = {v: k for k, v in self.mapping.items()}
self._num_classes = len(self.mapping)
self._input_shape = self.essentials["input_shape"]
- def __call__(self, token: Union[str, int, np.uint8]) -> Union[str, int]:
+ def __call__(self, token: Union[str, int, np.uint8, Tensor]) -> Union[str, int]:
"""Maps the token to emnist character or character index.
If the token is an integer (index), the method will return the Emnist character corresponding to that index.
If the token is a str (Emnist character), the method will return the corresponding index for that character.
Args:
- token (Union[str, int, np.uint8]): Eihter a string or index (integer).
+ token (Union[str, int, np.uint8, Tensor]): Either a string or index (integer).
Returns:
Union[str, int]: The mapping result.
@@ -79,9 +80,11 @@ class EmnistMapper:
KeyError: If the index or string does not exist in the mapping.
"""
- if (isinstance(token, np.uint8) or isinstance(token, int)) and int(
- token
- ) in self.mapping:
+ if (
+ (isinstance(token, np.uint8) or isinstance(token, int))
+ or torch.is_tensor(token)
+ and int(token) in self.mapping
+ ):
return self.mapping[int(token)]
elif isinstance(token, str) and token in self._inverse_mapping:
return self._inverse_mapping[token]
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index a945b41..d394b4c 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -15,8 +15,9 @@ from torch import Tensor
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.utils.data import DataLoader, Dataset, random_split
from torchsummary import summary
-from torchvision.transforms import Compose
+from text_recognizer import datasets
+from text_recognizer import networks
from text_recognizer.datasets import EmnistMapper
WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights"
@@ -27,8 +28,8 @@ class Model(ABC):
def __init__(
self,
- network_fn: Type[nn.Module],
- dataset: Type[Dataset],
+ network_fn: str,
+ dataset: str,
network_args: Optional[Dict] = None,
dataset_args: Optional[Dict] = None,
metrics: Optional[Dict] = None,
@@ -44,8 +45,8 @@ class Model(ABC):
"""Base class, to be inherited by model for specific type of data.
Args:
- network_fn (Type[nn.Module]): The PyTorch network.
- dataset (Type[Dataset]): A dataset class.
+ network_fn (str): The name of network.
+ dataset (str): The name dataset class.
network_args (Optional[Dict]): Arguments for the network. Defaults to None.
dataset_args (Optional[Dict]): Arguments for the dataset.
metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None.
@@ -62,13 +63,15 @@ class Model(ABC):
device (Optional[str]): Name of the device to train on. Defaults to None.
"""
+ self._name = f"{self.__class__.__name__}_{dataset}_{network_fn}"
# Has to be set in subclass.
self._mapper = None
# Placeholder.
self._input_shape = None
- self.dataset = dataset
+ self.dataset_name = dataset
+ self.dataset = None
self.dataset_args = dataset_args
# Placeholders for datasets.
@@ -92,10 +95,6 @@ class Model(ABC):
# Flag for stopping training.
self.stop_training = False
- self._name = (
- f"{self.__class__.__name__}_{dataset.__name__}_{network_fn.__name__}"
- )
-
self._metrics = metrics if metrics is not None else None
# Set the device.
@@ -132,38 +131,12 @@ class Model(ABC):
# Set this flag to true to prevent the model from configuring again.
self.is_configured = True
- def _configure_transforms(self) -> None:
- # Load transforms.
- transforms_module = importlib.import_module(
- "text_recognizer.datasets.transforms"
- )
- if (
- "transform" in self.dataset_args["args"]
- and self.dataset_args["args"]["transform"] is not None
- ):
- transform_ = []
- for t in self.dataset_args["args"]["transform"]:
- args = t["args"] or {}
- transform_.append(getattr(transforms_module, t["type"])(**args))
- self.dataset_args["args"]["transform"] = Compose(transform_)
-
- if (
- "target_transform" in self.dataset_args["args"]
- and self.dataset_args["args"]["target_transform"] is not None
- ):
- target_transform_ = [
- torch.tensor,
- ]
- for t in self.dataset_args["args"]["target_transform"]:
- args = t["args"] or {}
- target_transform_.append(getattr(transforms_module, t["type"])(**args))
- self.dataset_args["args"]["target_transform"] = Compose(target_transform_)
-
def prepare_data(self) -> None:
"""Prepare data for training."""
# TODO add downloading.
if not self.data_prepared:
- self._configure_transforms()
+ # Load dataset module.
+ self.dataset = getattr(datasets, self.dataset_name)
# Load train dataset.
train_dataset = self.dataset(train=True, **self.dataset_args["args"])
@@ -222,6 +195,8 @@ class Model(ABC):
def _configure_network(self, network_fn: Type[nn.Module]) -> None:
"""Loads the network."""
# If no network arguments are given, load pretrained weights if they exist.
+ # Load network module.
+ network_fn = getattr(networks, network_fn)
if self._network_args is None:
self.load_weights(network_fn)
else:
diff --git a/src/text_recognizer/tests/support/create_emnist_lines_support_files.py b/src/text_recognizer/tests/support/create_emnist_lines_support_files.py
index 4496e40..9abe143 100644
--- a/src/text_recognizer/tests/support/create_emnist_lines_support_files.py
+++ b/src/text_recognizer/tests/support/create_emnist_lines_support_files.py
@@ -1,4 +1,6 @@
"""Module for creating EMNIST Lines test support files."""
+# flake8: noqa: S106
+
from pathlib import Path
import shutil
@@ -17,23 +19,31 @@ def create_emnist_lines_support_files() -> None:
SUPPORT_DIRNAME.mkdir()
# TODO: maybe have to add args to dataset.
- dataset = EmnistLinesDataset()
+ dataset = EmnistLinesDataset(
+ init_token="<sos>",
+ pad_token="_",
+ eos_token="<eos>",
+ transform=[{"type": "ToTensor", "args": {}}],
+ target_transform=[
+ {
+ "type": "AddTokens",
+ "args": {"init_token": "<sos>", "pad_token": "_", "eos_token": "<eos>"},
+ }
+ ],
+ ) # nosec: S106
dataset.load_or_generate_data()
- for index in [0, 1, 3]:
+ for index in [5, 7, 9]:
image, target = dataset[index]
+ if len(image.shape) == 3:
+ image = image.squeeze(0)
print(image.sum(), image.dtype)
- label = (
- "".join(
- dataset.mapper[label]
- for label in np.argmax(target[1:], dim=-1).flatten()
- )
- .stip()
- .strip(dataset.mapper.pad_token)
+ label = "".join(dataset.mapper(label) for label in target[1:]).strip(
+ dataset.mapper.pad_token
)
-
print(label)
+ image = image.numpy()
util.write_image(image, str(SUPPORT_DIRNAME / f"{label}.png"))
diff --git a/src/text_recognizer/tests/support/create_iam_lines_support_files.py b/src/text_recognizer/tests/support/create_iam_lines_support_files.py
index bb568ee..50f9e3d 100644
--- a/src/text_recognizer/tests/support/create_iam_lines_support_files.py
+++ b/src/text_recognizer/tests/support/create_iam_lines_support_files.py
@@ -1,4 +1,5 @@
"""Module for creating IAM Lines test support files."""
+# flake8: noqa
from pathlib import Path
import shutil
@@ -17,23 +18,31 @@ def create_emnist_lines_support_files() -> None:
SUPPORT_DIRNAME.mkdir()
# TODO: maybe have to add args to dataset.
- dataset = IamLinesDataset()
+ dataset = IamLinesDataset(
+ init_token="<sos>",
+ pad_token="_",
+ eos_token="<eos>",
+ transform=[{"type": "ToTensor", "args": {}}],
+ target_transform=[
+ {
+ "type": "AddTokens",
+ "args": {"init_token": "<sos>", "pad_token": "_", "eos_token": "<eos>"},
+ }
+ ],
+ )
dataset.load_or_generate_data()
for index in [0, 1, 3]:
image, target = dataset[index]
+ if len(image.shape) == 3:
+ image = image.squeeze(0)
print(image.sum(), image.dtype)
- label = (
- "".join(
- dataset.mapper[label]
- for label in np.argmax(target[1:], dim=-1).flatten()
- )
- .stip()
- .strip(dataset.mapper.pad_token)
+ label = "".join(dataset.mapper(label) for label in target[1:]).strip(
+ dataset.mapper.pad_token
)
-
print(label)
+ image = image.numpy()
util.write_image(image, str(SUPPORT_DIRNAME / f"{label}.png"))
diff --git a/src/text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png b/src/text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png
new file mode 100644
index 0000000..b7d0618
--- /dev/null
+++ b/src/text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png
Binary files differ
diff --git a/src/text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png b/src/text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png
new file mode 100644
index 0000000..14a8cf3
--- /dev/null
+++ b/src/text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png
Binary files differ
diff --git a/src/text_recognizer/tests/support/emnist_lines/they<eos>.png b/src/text_recognizer/tests/support/emnist_lines/they<eos>.png
new file mode 100644
index 0000000..7f05951
--- /dev/null
+++ b/src/text_recognizer/tests/support/emnist_lines/they<eos>.png
Binary files differ
diff --git a/src/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png b/src/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png
new file mode 100644
index 0000000..6eeb642
--- /dev/null
+++ b/src/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png
Binary files differ
diff --git a/src/text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png b/src/text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png
new file mode 100644
index 0000000..4974cf8
--- /dev/null
+++ b/src/text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png
Binary files differ
diff --git a/src/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png b/src/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png
new file mode 100644
index 0000000..a731245
--- /dev/null
+++ b/src/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png
Binary files differ
diff --git a/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt b/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt
deleted file mode 100644
index 04e1952..0000000
--- a/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:b7fa37f24951732e68bf2191fd58e9c332848d5800bf47dc1aa1a9c32c63485a
-size 61946486
diff --git a/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt b/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt
deleted file mode 100644
index 50a6a20..0000000
--- a/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:f2e458cf749526ffa876709fc76084d8099a490d9ff24f9be36b8b04c037f073
-size 3457858
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py
index e6ae84c..55a9572 100644
--- a/src/training/run_experiment.py
+++ b/src/training/run_experiment.py
@@ -74,8 +74,7 @@ def _load_modules_and_arguments(experiment_config: Dict,) -> Tuple[Callable, Dic
"""Loads all modules and arguments."""
# Load the dataset module.
dataset_args = experiment_config.get("dataset", {})
- datasets_module = importlib.import_module("text_recognizer.datasets")
- dataset_ = getattr(datasets_module, dataset_args["type"])
+ dataset_ = dataset_args["type"]
# Import the model module and model arguments.
models_module = importlib.import_module("text_recognizer.models")
@@ -92,8 +91,7 @@ def _load_modules_and_arguments(experiment_config: Dict,) -> Tuple[Callable, Dic
)
# Import network module and arguments.
- network_module = importlib.import_module("text_recognizer.networks")
- network_fn_ = getattr(network_module, experiment_config["network"]["type"])
+ network_fn_ = experiment_config["network"]["type"]
network_args = experiment_config["network"].get("args", {})
# Criterion