summaryrefslogtreecommitdiff
path: root/src/text_recognizer
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer')
-rw-r--r--src/text_recognizer/datasets/__init__.py12
-rw-r--r--src/text_recognizer/datasets/emnist_dataset.py32
-rw-r--r--src/text_recognizer/datasets/emnist_lines_dataset.py38
-rw-r--r--src/text_recognizer/datasets/iam_dataset.py134
-rw-r--r--src/text_recognizer/datasets/iam_lines_dataset.py126
-rw-r--r--src/text_recognizer/datasets/iam_paragraphs_dataset.py322
-rw-r--r--src/text_recognizer/datasets/util.py99
-rw-r--r--src/text_recognizer/models/__init__.py5
-rw-r--r--src/text_recognizer/models/base.py331
-rw-r--r--src/text_recognizer/models/character_model.py20
-rw-r--r--src/text_recognizer/models/line_ctc_model.py105
-rw-r--r--src/text_recognizer/models/metrics.py80
-rw-r--r--src/text_recognizer/networks/__init__.py17
-rw-r--r--src/text_recognizer/networks/ctc.py66
-rw-r--r--src/text_recognizer/networks/lenet.py12
-rw-r--r--src/text_recognizer/networks/line_lstm_ctc.py76
-rw-r--r--src/text_recognizer/networks/misc.py5
-rw-r--r--src/text_recognizer/networks/mlp.py6
-rw-r--r--src/text_recognizer/networks/residual_network.py35
-rw-r--r--src/text_recognizer/networks/stn.py44
-rw-r--r--src/text_recognizer/networks/wide_resnet.py214
-rw-r--r--src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.ptbin11625484 -> 17938163 bytes
-rw-r--r--src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.ptbin4562821 -> 5003730 bytes
-rw-r--r--src/text_recognizer/weights/CharacterModel_EmnistDataset_SpinalVGG_weights.ptbin0 -> 44089479 bytes
-rw-r--r--src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.ptbin0 -> 15375126 bytes
25 files changed, 1537 insertions, 242 deletions
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py
index 05f74f6..ede4541 100644
--- a/src/text_recognizer/datasets/__init__.py
+++ b/src/text_recognizer/datasets/__init__.py
@@ -10,15 +10,23 @@ from .emnist_lines_dataset import (
EmnistLinesDataset,
get_samples_by_character,
)
-from .util import fetch_data_loaders, Transpose
+from .iam_dataset import IamDataset
+from .iam_lines_dataset import IamLinesDataset
+from .iam_paragraphs_dataset import IamParagraphsDataset
+from .util import _download_raw_dataset, compute_sha256, download_url, Transpose
__all__ = [
+ "_download_raw_dataset",
+ "compute_sha256",
"construct_image_from_string",
"DATA_DIRNAME",
+ "download_url",
"EmnistDataset",
"EmnistMapper",
"EmnistLinesDataset",
- "fetch_data_loaders",
"get_samples_by_character",
+ "IamDataset",
+ "IamLinesDataset",
+ "IamParagraphsDataset",
"Transpose",
]
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index 49ebad3..0715aae 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -152,8 +152,7 @@ class EmnistDataset(Dataset):
"""Loads the dataset and the mappings.
Args:
- train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to
- False.
+ train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False.
sample_to_balance (bool): Resamples the dataset to make it balanced. Defaults to False.
subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None.
transform (Optional[Callable]): Transform(s) for input data. Defaults to None.
@@ -181,17 +180,37 @@ class EmnistDataset(Dataset):
self.seed = seed
self._mapper = EmnistMapper()
- self.input_shape = self._mapper.input_shape
+ self._input_shape = self._mapper.input_shape
self.num_classes = self._mapper.num_classes
# Load dataset.
- self.data, self.targets = self.load_emnist_dataset()
+ self._data, self._targets = self.load_emnist_dataset()
+
+ @property
+ def data(self) -> Tensor:
+ """The input data."""
+ return self._data
+
+ @property
+ def targets(self) -> Tensor:
+ """The target data."""
+ return self._targets
+
+ @property
+ def input_shape(self) -> Tuple:
+ """Input shape of the data."""
+ return self._input_shape
@property
def mapper(self) -> EmnistMapper:
"""Returns the EmnistMapper."""
return self._mapper
+ @property
+ def inverse_mapping(self) -> Dict:
+ """Returns the inverse mapping from character to index."""
+ return self.mapper.inverse_mapping
+
def __len__(self) -> int:
"""Returns the length of the dataset."""
return len(self.data)
@@ -220,11 +239,6 @@ class EmnistDataset(Dataset):
return data, targets
- @property
- def __name__(self) -> str:
- """Returns the name of the dataset."""
- return "EmnistDataset"
-
def __repr__(self) -> str:
"""Returns information about the dataset."""
return (
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index b0617f5..656131a 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -9,8 +9,8 @@ from loguru import logger
import numpy as np
import torch
from torch import Tensor
-from torch.utils.data import DataLoader, Dataset
-from torchvision.transforms import Compose, Normalize, ToTensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
from text_recognizer.datasets import (
DATA_DIRNAME,
@@ -20,6 +20,7 @@ from text_recognizer.datasets import (
)
from text_recognizer.datasets.sentence_generator import SentenceGenerator
from text_recognizer.datasets.util import Transpose
+from text_recognizer.networks import sliding_window
DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines"
@@ -55,7 +56,7 @@ class EmnistLinesDataset(Dataset):
self.transform = transform
if self.transform is None:
- self.transform = Compose([ToTensor()])
+ self.transform = ToTensor()
self.target_transform = target_transform
if self.target_transform is None:
@@ -63,14 +64,14 @@ class EmnistLinesDataset(Dataset):
# Extract dataset information.
self._mapper = EmnistMapper()
- self.input_shape = self._mapper.input_shape
+ self._input_shape = self._mapper.input_shape
self.num_classes = self._mapper.num_classes
self.max_length = max_length
self.min_overlap = min_overlap
self.max_overlap = max_overlap
self.num_samples = num_samples
- self.input_shape = (
+ self._input_shape = (
self.input_shape[0],
self.input_shape[1] * self.max_length,
)
@@ -84,6 +85,11 @@ class EmnistLinesDataset(Dataset):
# Load dataset.
self._load_or_generate_data()
+ @property
+ def input_shape(self) -> Tuple:
+ """Input shape of the data."""
+ return self._input_shape
+
def __len__(self) -> int:
"""Returns the length of the dataset."""
return len(self.data)
@@ -112,11 +118,6 @@ class EmnistLinesDataset(Dataset):
return data, targets
- @property
- def __name__(self) -> str:
- """Returns the name of the dataset."""
- return "EmnistLinesDataset"
-
def __repr__(self) -> str:
"""Returns information about the dataset."""
return (
@@ -136,13 +137,18 @@ class EmnistLinesDataset(Dataset):
return self._mapper
@property
+ def mapping(self) -> Dict:
+ """Return EMNIST mapping from index to character."""
+ return self._mapper.mapping
+
+ @property
def data_filename(self) -> Path:
"""Path to the h5 file."""
filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt"
if self.train:
filename = "train_" + filename
else:
- filename = "val_" + filename
+ filename = "test_" + filename
return DATA_DIRNAME / filename
def _load_or_generate_data(self) -> None:
@@ -184,7 +190,7 @@ class EmnistLinesDataset(Dataset):
)
targets = convert_strings_to_categorical_labels(
- targets, self.emnist.inverse_mapping
+ targets, emnist.inverse_mapping
)
f.create_dataset("data", data=data, dtype="u1", compression="lzf")
@@ -322,13 +328,13 @@ def create_datasets(
min_overlap: float = 0,
max_overlap: float = 0.33,
num_train: int = 10000,
- num_val: int = 1000,
+ num_test: int = 1000,
) -> None:
"""Creates a training an validation dataset of Emnist lines."""
emnist_train = EmnistDataset(train=True, sample_to_balance=True)
- emnist_val = EmnistDataset(train=False, sample_to_balance=True)
- datasets = [emnist_train, emnist_val]
- num_samples = [num_train, num_val]
+ emnist_test = EmnistDataset(train=False, sample_to_balance=True)
+ datasets = [emnist_train, emnist_test]
+ num_samples = [num_train, num_test]
for num, train, dataset in zip(num_samples, [True, False], datasets):
emnist_lines = EmnistLinesDataset(
train=train,
diff --git a/src/text_recognizer/datasets/iam_dataset.py b/src/text_recognizer/datasets/iam_dataset.py
new file mode 100644
index 0000000..5e47350
--- /dev/null
+++ b/src/text_recognizer/datasets/iam_dataset.py
@@ -0,0 +1,134 @@
+"""Class for loading the IAM dataset, which encompasses both paragraphs and lines, with associated utilities."""
+import os
+from typing import Any, Dict, List
+import zipfile
+
+from boltons.cacheutils import cachedproperty
+import defusedxml.ElementTree as ET
+from loguru import logger
+import toml
+from torch.utils.data import Dataset
+
+from text_recognizer.datasets import DATA_DIRNAME
+from text_recognizer.datasets.util import _download_raw_dataset
+
+RAW_DATA_DIRNAME = DATA_DIRNAME / "raw" / "iam"
+METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
+EXTRACTED_DATASET_DIRNAME = RAW_DATA_DIRNAME / "iamdb"
+
+DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be.
+LINE_REGION_PADDING = 0 # Add this many pixels around the exact coordinates.
+
+
+class IamDataset(Dataset):
+ """IAM dataset.
+
+ "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text,
+ which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels."
+ From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database
+
+ The data split we will use is
+ IAM lines Large Writer Independent Text Line Recognition Task (lwitlrt): 9,862 text lines.
+ The validation set has been merged into the train set.
+ The train set has 7,101 lines from 326 writers.
+ The test set has 1,861 lines from 128 writers.
+ The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only.
+
+ """
+
+ def __init__(self) -> None:
+ self.metadata = toml.load(METADATA_FILENAME)
+
+ def load_or_generate_data(self) -> None:
+ """Downloads IAM dataset if xml files does not exist."""
+ if not self.xml_filenames:
+ self._download_iam()
+
+ @property
+ def xml_filenames(self) -> List:
+ """List of xml filenames."""
+ return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml"))
+
+ @property
+ def form_filenames(self) -> List:
+ """List of forms filenames."""
+ return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg"))
+
+ def _download_iam(self) -> None:
+ curdir = os.getcwd()
+ os.chdir(RAW_DATA_DIRNAME)
+ _download_raw_dataset(self.metadata)
+ _extract_raw_dataset(self.metadata)
+ os.chdir(curdir)
+
+ @property
+ def form_filenames_by_id(self) -> Dict:
+ """Creates a dictionary with filenames as keys and forms as values."""
+ return {filename.stem: filename for filename in self.form_filenames}
+
+ @cachedproperty
+ def line_strings_by_id(self) -> Dict:
+ """Return a dict from name of IAM form to a list of line texts in it."""
+ return {
+ filename.stem: _get_line_strings_from_xml_file(filename)
+ for filename in self.xml_filenames
+ }
+
+ @cachedproperty
+ def line_regions_by_id(self) -> Dict:
+ """Return a dict from name of IAM form to a list of (x1, x2, y1, y2) coordinates of all lines in it."""
+ return {
+ filename.stem: _get_line_regions_from_xml_file(filename)
+ for filename in self.xml_filenames
+ }
+
+ def __repr__(self) -> str:
+ """Print info about dataset."""
+ return "IAM Dataset\n" f"Number of forms: {len(self.xml_filenames)}\n"
+
+
+def _extract_raw_dataset(metadata: Dict) -> None:
+ logger.info("Extracting IAM data.")
+ with zipfile.ZipFile(metadata["filename"], "r") as zip_file:
+ zip_file.extractall()
+
+
+def _get_line_strings_from_xml_file(filename: str) -> List[str]:
+ """Get the text content of each line. Note that we replace " with "."""
+ xml_root_element = ET.parse(filename).getroot() # nosec
+ xml_line_elements = xml_root_element.findall("handwritten-part/line")
+ return [el.attrib["text"].replace(""", '"') for el in xml_line_elements]
+
+
+def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]:
+ """Get the line region dict for each line."""
+ xml_root_element = ET.parse(filename).getroot() # nosec
+ xml_line_elements = xml_root_element.findall("handwritten-part/line")
+ return [_get_line_region_from_xml_element(el) for el in xml_line_elements]
+
+
+def _get_line_region_from_xml_element(xml_line: Any) -> Dict[str, int]:
+ """Extracts coordinates for each line of text."""
+ # TODO: fix input!
+ word_elements = xml_line.findall("word/cmp")
+ x1s = [int(el.attrib["x"]) for el in word_elements]
+ y1s = [int(el.attrib["y"]) for el in word_elements]
+ x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements]
+ y2s = [int(el.attrib["y"]) + int(el.attrib["height"]) for el in word_elements]
+ return {
+ "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
+ "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
+ "x2": max(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
+ "y2": max(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
+ }
+
+
+def main() -> None:
+ """Initializes the dataset and print info about the dataset."""
+ dataset = IamDataset()
+ dataset.load_or_generate_data()
+ print(dataset)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/text_recognizer/datasets/iam_lines_dataset.py b/src/text_recognizer/datasets/iam_lines_dataset.py
new file mode 100644
index 0000000..477f500
--- /dev/null
+++ b/src/text_recognizer/datasets/iam_lines_dataset.py
@@ -0,0 +1,126 @@
+"""IamLinesDataset class."""
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import h5py
+from loguru import logger
+import torch
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets.emnist_dataset import DATA_DIRNAME, EmnistMapper
+from text_recognizer.datasets.util import compute_sha256, download_url
+
+
+PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_lines"
+PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "iam_lines.h5"
+PROCESSED_DATA_URL = (
+ "https://s3-us-west-2.amazonaws.com/fsdl-public-assets/iam_lines.h5"
+)
+
+
+class IamLinesDataset(Dataset):
+ """IAM lines datasets for handwritten text lines."""
+
+ def __init__(
+ self,
+ train: bool = False,
+ subsample_fraction: float = None,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ ) -> None:
+ self.train = train
+ self.split = "train" if self.train else "test"
+ self._mapper = EmnistMapper()
+ self.num_classes = self.mapper.num_classes
+
+ # Set transforms.
+ self.transform = transform
+ if self.transform is None:
+ self.transform = ToTensor()
+
+ self.target_transform = target_transform
+ if self.target_transform is None:
+ self.target_transform = torch.tensor
+
+ self.subsample_fraction = subsample_fraction
+ self.data = None
+ self.targets = None
+
+ @property
+ def mapper(self) -> EmnistMapper:
+ """Returns the EmnistMapper."""
+ return self._mapper
+
+ @property
+ def mapping(self) -> Dict:
+ """Return EMNIST mapping from index to character."""
+ return self._mapper.mapping
+
+ @property
+ def input_shape(self) -> Tuple:
+ """Input shape of the data."""
+ return self.data.shape[1:]
+
+ @property
+ def output_shape(self) -> Tuple:
+ """Output shape of the data."""
+ return self.targets.shape[1:] + (self.num_classes,)
+
+ def __len__(self) -> int:
+ """Returns the length of the dataset."""
+ return len(self.data)
+
+ def load_or_generate_data(self) -> None:
+ """Load or generate dataset data."""
+ if not PROCESSED_DATA_FILENAME.exists():
+ PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
+ logger.info("Downloading IAM lines...")
+ download_url(PROCESSED_DATA_URL, PROCESSED_DATA_FILENAME)
+ with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
+ self.data = f[f"x_{self.split}"][:]
+ self.targets = f[f"y_{self.split}"][:]
+ self._subsample()
+
+ def _subsample(self) -> None:
+ """Only a fraction of the data will be loaded."""
+ if self.subsample_fraction is None:
+ return
+
+ num_samples = int(self.data.shape[0] * self.subsample_fraction)
+ self.data = self.data[:num_samples]
+ self.targets = self.targets[:num_samples]
+
+ def __repr__(self) -> str:
+ """Print info about the dataset."""
+ return (
+ "IAM Lines Dataset\n" # pylint: disable=no-member
+ f"Number classes: {self.num_classes}\n"
+ f"Mapping: {self.mapper.mapping}\n"
+ f"Data: {self.data.shape}\n"
+ f"Targets: {self.targets.shape}\n"
+ )
+
+ def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]:
+ """Fetches data, target pair of the dataset for a given and index or indices.
+
+ Args:
+ index (Union[int, Tensor]): Either a list or int of indices/index.
+
+ Returns:
+ Tuple[Tensor, Tensor]: Data target pair.
+
+ """
+ if torch.is_tensor(index):
+ index = index.tolist()
+
+ data = self.data[index]
+ targets = self.targets[index]
+
+ if self.transform:
+ data = self.transform(data)
+
+ if self.target_transform:
+ targets = self.target_transform(targets)
+
+ return data, targets
diff --git a/src/text_recognizer/datasets/iam_paragraphs_dataset.py b/src/text_recognizer/datasets/iam_paragraphs_dataset.py
new file mode 100644
index 0000000..d65b346
--- /dev/null
+++ b/src/text_recognizer/datasets/iam_paragraphs_dataset.py
@@ -0,0 +1,322 @@
+"""IamParagraphsDataset class and functions for data processing."""
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import click
+import cv2
+import h5py
+from loguru import logger
+import numpy as np
+import torch
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer import util
+from text_recognizer.datasets.emnist_dataset import DATA_DIRNAME, EmnistMapper
+from text_recognizer.datasets.iam_dataset import IamDataset
+from text_recognizer.datasets.util import compute_sha256, download_url
+
+INTERIM_DATA_DIRNAME = DATA_DIRNAME / "interim" / "iam_paragraphs"
+DEBUG_CROPS_DIRNAME = INTERIM_DATA_DIRNAME / "debug_crops"
+PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_paragraphs"
+CROPS_DIRNAME = PROCESSED_DATA_DIRNAME / "crops"
+GT_DIRNAME = PROCESSED_DATA_DIRNAME / "gt"
+
+PARAGRAPH_BUFFER = 50 # Pixels in the IAM form images to leave around the lines.
+TEST_FRACTION = 0.2
+SEED = 4711
+
+
+class IamParagraphsDataset(Dataset):
+ """IAM Paragraphs dataset for paragraphs of handwritten text.
+
+ TODO: __getitem__, __len__, get_data_target_from_id
+
+ """
+
+ def __init__(
+ self,
+ train: bool = False,
+ subsample_fraction: float = None,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ ) -> None:
+
+ # Load Iam dataset.
+ self.iam_dataset = IamDataset()
+
+ self.train = train
+ self.split = "train" if self.train else "test"
+ self.num_classes = 3
+ self._input_shape = (256, 256)
+ self._output_shape = self._input_shape + (self.num_classes,)
+ self.subsample_fraction = subsample_fraction
+
+ # Set transforms.
+ self.transform = transform
+ if self.transform is None:
+ self.transform = ToTensor()
+
+ self.target_transform = target_transform
+ if self.target_transform is None:
+ self.target_transform = torch.tensor
+
+ self._data = None
+ self._targets = None
+ self._ids = None
+
+ def __len__(self) -> int:
+ """Returns the length of the dataset."""
+ return len(self.data)
+
+ def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]:
+ """Fetches data, target pair of the dataset for a given and index or indices.
+
+ Args:
+ index (Union[int, Tensor]): Either a list or int of indices/index.
+
+ Returns:
+ Tuple[Tensor, Tensor]: Data target pair.
+
+ """
+ if torch.is_tensor(index):
+ index = index.tolist()
+
+ data = self.data[index]
+ targets = self.targets[index]
+
+ if self.transform:
+ data = self.transform(data)
+
+ if self.target_transform:
+ targets = self.target_transform(targets)
+
+ return data, targets
+
+ @property
+ def input_shape(self) -> Tuple:
+ """Input shape of the data."""
+ return self._input_shape
+
+ @property
+ def output_shape(self) -> Tuple:
+ """Output shape of the data."""
+ return self._output_shape
+
+ @property
+ def data(self) -> Tensor:
+ """The input data."""
+ return self._data
+
+ @property
+ def targets(self) -> Tensor:
+ """The target data."""
+ return self._targets
+
+ @property
+ def ids(self) -> Tensor:
+ """Ids of the dataset."""
+ return self._ids
+
+ def get_data_and_target_from_id(self, id_: str) -> Tuple[Tensor, Tensor]:
+ """Get data target pair from id."""
+ ind = self.ids.index(id_)
+ return self.data[ind], self.targets[ind]
+
+ def load_or_generate_data(self) -> None:
+ """Load or generate dataset data."""
+ num_actual = len(list(CROPS_DIRNAME.glob("*.jpg")))
+ num_targets = len(self.iam_dataset.line_regions_by_id)
+
+ if num_actual < num_targets - 2:
+ self._process_iam_paragraphs()
+
+ self._data, self._targets, self._ids = _load_iam_paragraphs()
+ self._get_random_split()
+ self._subsample()
+
+ def _get_random_split(self) -> None:
+ np.random.seed(SEED)
+ num_train = int((1 - TEST_FRACTION) * self.data.shape[0])
+ indices = np.random.permutation(self.data.shape[0])
+ train_indices, test_indices = indices[:num_train], indices[num_train:]
+ if self.train:
+ self._data = self.data[train_indices]
+ self._targets = self.targets[train_indices]
+ else:
+ self._data = self.data[test_indices]
+ self._targets = self.targets[test_indices]
+
+ def _process_iam_paragraphs(self) -> None:
+ """Crop the part with the text.
+
+ For each page, crop out the part of it that correspond to the paragraph of text, and make sure all crops are
+ self.input_shape. The ground truth data is the same size, with a one-hot vector at each pixel
+ corresponding to labels 0=background, 1=odd-numbered line, 2=even-numbered line
+ """
+ crop_dims = self._decide_on_crop_dims()
+ CROPS_DIRNAME.mkdir(parents=True, exist_ok=True)
+ DEBUG_CROPS_DIRNAME.mkdir(parents=True, exist_ok=True)
+ GT_DIRNAME.mkdir(parents=True, exist_ok=True)
+ logger.info(
+ f"Cropping paragraphs, generating ground truth, and saving debugging images to {DEBUG_CROPS_DIRNAME}"
+ )
+ for filename in self.iam_dataset.form_filenames:
+ id_ = filename.stem
+ line_region = self.iam_dataset.line_regions_by_id[id_]
+ _crop_paragraph_image(filename, line_region, crop_dims, self.input_shape)
+
+ def _decide_on_crop_dims(self) -> Tuple[int, int]:
+ """Decide on the dimensions to crop out of the form image.
+
+ Since image width is larger than a comfortable crop around the longest paragraph,
+ we will make the crop a square form factor.
+ And since the found dimensions 610x610 are pretty close to 512x512,
+ we might as well resize crops and make it exactly that, which lets us
+ do all kinds of power-of-2 pooling and upsampling should we choose to.
+
+ Returns:
+ Tuple[int, int]: A tuple of crop dimensions.
+
+ Raises:
+ RuntimeError: When max crop height is larger than max crop width.
+
+ """
+
+ sample_form_filename = self.iam_dataset.form_filenames[0]
+ sample_image = util.read_image(sample_form_filename, grayscale=True)
+ max_crop_width = sample_image.shape[1]
+ max_crop_height = _get_max_paragraph_crop_height(
+ self.iam_dataset.line_regions_by_id
+ )
+ if not max_crop_height <= max_crop_width:
+ raise RuntimeError(
+ f"Max crop height is larger then max crop width: {max_crop_height} >= {max_crop_width}"
+ )
+
+ crop_dims = (max_crop_width, max_crop_width)
+ logger.info(
+ f"Max crop width and height were found to be {max_crop_width}x{max_crop_height}."
+ )
+ logger.info(f"Setting them to {max_crop_width}x{max_crop_width}")
+ return crop_dims
+
+ def _subsample(self) -> None:
+ """Only this fraction of the data will be loaded."""
+ if self.subsample_fraction is None:
+ return
+ num_subsample = int(self.data.shape[0] * self.subsample_fraction)
+ self.data = self.data[:num_subsample]
+ self.targets = self.targets[:num_subsample]
+
+ def __repr__(self) -> str:
+ """Return info about the dataset."""
+ return (
+ "IAM Paragraph Dataset\n" # pylint: disable=no-member
+ f"Num classes: {self.num_classes}\n"
+ f"Data: {self.data.shape}\n"
+ f"Targets: {self.targets.shape}\n"
+ )
+
+
+def _get_max_paragraph_crop_height(line_regions_by_id: Dict) -> int:
+ heights = []
+ for regions in line_regions_by_id.values():
+ min_y1 = min(region["y1"] for region in regions) - PARAGRAPH_BUFFER
+ max_y2 = max(region["y2"] for region in regions) + PARAGRAPH_BUFFER
+ height = max_y2 - min_y1
+ heights.append(height)
+ return max(heights)
+
+
+def _crop_paragraph_image(
+ filename: str, line_regions: Dict, crop_dims: Tuple[int, int], final_dims: Tuple
+) -> None:
+ image = util.read_image(filename, grayscale=True)
+
+ min_y1 = min(region["y1"] for region in line_regions) - PARAGRAPH_BUFFER
+ max_y2 = max(region["y2"] for region in line_regions) + PARAGRAPH_BUFFER
+ height = max_y2 - min_y1
+ crop_height = crop_dims[0]
+ buffer = (crop_height - height) // 2
+
+ # Generate image crop.
+ image_crop = 255 * np.ones(crop_dims, dtype=np.uint8)
+ try:
+ image_crop[buffer : buffer + height] = image[min_y1:max_y2]
+ except Exception as e: # pylint: disable=broad-except
+ logger.error(f"Rescued {filename}: {e}")
+ return
+
+ # Generate ground truth.
+ gt_image = np.zeros_like(image_crop, dtype=np.uint8)
+ for index, region in enumerate(line_regions):
+ gt_image[
+ (region["y1"] - min_y1 + buffer) : (region["y2"] - min_y1 + buffer),
+ region["x1"] : region["x2"],
+ ] = (index % 2 + 1)
+
+ # Generate image for debugging.
+ import matplotlib.pyplot as plt
+
+ cmap = plt.get_cmap("Set1")
+ image_crop_for_debug = np.dstack([image_crop, image_crop, image_crop])
+ for index, region in enumerate(line_regions):
+ color = [255 * _ for _ in cmap(index)[:-1]]
+ cv2.rectangle(
+ image_crop_for_debug,
+ (region["x1"], region["y1"] - min_y1 + buffer),
+ (region["x2"], region["y2"] - min_y1 + buffer),
+ color,
+ 3,
+ )
+ image_crop_for_debug = cv2.resize(
+ image_crop_for_debug, final_dims, interpolation=cv2.INTER_AREA
+ )
+ util.write_image(image_crop_for_debug, DEBUG_CROPS_DIRNAME / f"{filename.stem}.jpg")
+
+ image_crop = cv2.resize(image_crop, final_dims, interpolation=cv2.INTER_AREA)
+ util.write_image(image_crop, CROPS_DIRNAME / f"{filename.stem}.jpg")
+
+ gt_image = cv2.resize(gt_image, final_dims, interpolation=cv2.INTER_NEAREST)
+ util.write_image(gt_image, GT_DIRNAME / f"{filename.stem}.png")
+
+
+def _load_iam_paragraphs() -> None:
+ logger.info("Loading IAM paragraph crops and ground truth from image files...")
+ images = []
+ gt_images = []
+ ids = []
+ for filename in CROPS_DIRNAME.glob("*.jpg"):
+ id_ = filename.stem
+ image = util.read_image(filename, grayscale=True)
+ image = 1.0 - image / 255
+
+ gt_filename = GT_DIRNAME / f"{id_}.png"
+ gt_image = util.read_image(gt_filename, grayscale=True)
+
+ images.append(image)
+ gt_images.append(gt_image)
+ ids.append(id_)
+ images = np.array(images).astype(np.float32)
+ gt_images = np.array(gt_images).astype(np.uint8)
+ ids = np.array(ids)
+ return images, gt_images, ids
+
+
+@click.command()
+@click.option(
+ "--subsample_fraction",
+ type=float,
+ default=0.0,
+ help="The subsampling factor of the dataset.",
+)
+def main(subsample_fraction: float) -> None:
+ """Load dataset and print info."""
+ dataset = IamParagraphsDataset(subsample_fraction=subsample_fraction)
+ dataset.load_or_generate_data()
+ print(dataset)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
index 76bd85f..dd16bed 100644
--- a/src/text_recognizer/datasets/util.py
+++ b/src/text_recognizer/datasets/util.py
@@ -1,10 +1,17 @@
"""Util functions for datasets."""
+import hashlib
import importlib
-from typing import Callable, Dict, List, Type
+import os
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Type, Union
+from urllib.request import urlopen, urlretrieve
+import cv2
+from loguru import logger
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
class Transpose:
@@ -15,58 +22,48 @@ class Transpose:
return np.array(image).swapaxes(0, 1)
-def fetch_data_loaders(
- splits: List[str],
- dataset: str,
- dataset_args: Dict,
- batch_size: int = 128,
- shuffle: bool = False,
- num_workers: int = 0,
- cuda: bool = True,
-) -> Dict[str, DataLoader]:
- """Fetches DataLoaders for given split(s) as a dictionary.
-
- Loads the dataset class given, and loads it with the dataset arguments, for the number of splits specified. Then
- calls the DataLoader. Added to a dictionary with the split as key and DataLoader as value.
-
- Args:
- splits (List[str]): One or both of the dataset splits "train" and "val".
- dataset (str): The name of the dataset.
- dataset_args (Dict): The dataset arguments.
- batch_size (int): How many samples per batch to load. Defaults to 128.
- shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False.
- num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be
- loaded in the main process. Defaults to 0.
- cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning
- them. Defaults to True.
-
- Returns:
- Dict[str, DataLoader]: Dictionary with split as key and PyTorch DataLoader as value.
+def compute_sha256(filename: Union[Path, str]) -> str:
+ """Returns the SHA256 checksum of a file."""
+ with open(filename, "rb") as f:
+ return hashlib.sha256(f.read()).hexdigest()
- """
-
- def check_dataset_args(args: Dict, split: str) -> Dict:
- """Adds train flag to the dataset args."""
- args["train"] = True if split == "train" else False
- return args
-
- # Import dataset module.
- datasets_module = importlib.import_module("text_recognizer.datasets")
- dataset_ = getattr(datasets_module, dataset)
- data_loaders = {}
+class TqdmUpTo(tqdm):
+ """TQDM progress bar when downloading files.
- for split in ["train", "val"]:
- if split in splits:
+ From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
- data_loader = DataLoader(
- dataset=dataset_(**check_dataset_args(dataset_args, split)),
- batch_size=batch_size,
- shuffle=shuffle,
- num_workers=num_workers,
- pin_memory=cuda,
- )
-
- data_loaders[split] = data_loader
+ """
- return data_loaders
+ def update_to(
+ self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None
+ ) -> None:
+ """Updates the progress bar.
+
+ Args:
+ blocks (int): Number of blocks transferred so far. Defaults to 1.
+ block_size (int): Size of each block, in tqdm units. Defaults to 1.
+ total_size (Optional[int]): Total size in tqdm units. Defaults to None.
+ """
+ if total_size is not None:
+ self.total = total_size # pylint: disable=attribute-defined-outside-init
+ self.update(blocks * block_size - self.n)
+
+
+def download_url(url: str, filename: str) -> None:
+ """Downloads a file from url to filename, with a progress bar."""
+ with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
+ urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec
+
+
+def _download_raw_dataset(metadata: Dict) -> None:
+ if os.path.exists(metadata["filename"]):
+ return
+ logger.info(f"Downloading raw dataset from {metadata['url']}...")
+ download_url(metadata["url"], metadata["filename"])
+ logger.info("Computing SHA-256...")
+ sha256 = compute_sha256(metadata["filename"])
+ if sha256 != metadata["sha256"]:
+ raise ValueError(
+ "Downloaded data file SHA-256 does not match that listed in metadata document."
+ )
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py
index ff10a07..a3cfc15 100644
--- a/src/text_recognizer/models/__init__.py
+++ b/src/text_recognizer/models/__init__.py
@@ -1,6 +1,7 @@
"""Model modules."""
from .base import Model
from .character_model import CharacterModel
-from .metrics import accuracy
+from .line_ctc_model import LineCTCModel
+from .metrics import accuracy, cer, wer
-__all__ = ["Model", "CharacterModel", "accuracy"]
+__all__ = ["Model", "cer", "CharacterModel", "LineCTCModel", "accuracy", "wer"]
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index 3a84a11..153e19a 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -2,6 +2,7 @@
from abc import ABC, abstractmethod
from glob import glob
+import importlib
from pathlib import Path
import re
import shutil
@@ -10,9 +11,12 @@ from typing import Callable, Dict, Optional, Tuple, Type
from loguru import logger
import torch
from torch import nn
+from torch import Tensor
+from torch.optim.swa_utils import AveragedModel, SWALR
+from torch.utils.data import DataLoader, Dataset, random_split
from torchsummary import summary
-from text_recognizer.datasets import EmnistMapper, fetch_data_loaders
+from text_recognizer.datasets import EmnistMapper
WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights"
@@ -23,8 +27,9 @@ class Model(ABC):
def __init__(
self,
network_fn: Type[nn.Module],
+ dataset: Type[Dataset],
network_args: Optional[Dict] = None,
- data_loader_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
metrics: Optional[Dict] = None,
criterion: Optional[Callable] = None,
criterion_args: Optional[Dict] = None,
@@ -32,14 +37,16 @@ class Model(ABC):
optimizer_args: Optional[Dict] = None,
lr_scheduler: Optional[Callable] = None,
lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
device: Optional[str] = None,
) -> None:
"""Base class, to be inherited by model for specific type of data.
Args:
network_fn (Type[nn.Module]): The PyTorch network.
+ dataset (Type[Dataset]): A dataset class.
network_args (Optional[Dict]): Arguments for the network. Defaults to None.
- data_loader_args (Optional[Dict]): Arguments for the DataLoader.
+ dataset_args (Optional[Dict]): Arguments for the dataset.
metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None.
criterion (Optional[Callable]): The criterion to evaulate the preformance of the network.
Defaults to None.
@@ -49,107 +56,181 @@ class Model(ABC):
lr_scheduler (Optional[Callable]): A PyTorch learning rate scheduler. Defaults to None.
lr_scheduler_args (Optional[Dict]): Dict of arguments for learning rate scheduler. Defaults to
None.
+ swa_args (Optional[Dict]): Dict of arguments for stochastic weight averaging. Defaults to
+ None.
device (Optional[str]): Name of the device to train on. Defaults to None.
"""
+ # Has to be set in subclass.
+ self._mapper = None
- # Configure data loaders and dataset info.
- dataset_name, self._data_loaders, self._mapper = self._configure_data_loader(
- data_loader_args
- )
- self._input_shape = self._mapper.input_shape
+ # Placeholder.
+ self._input_shape = None
+
+ self.dataset = dataset
+ self.dataset_args = dataset_args
+
+ # Placeholders for datasets.
+ self.train_dataset = None
+ self.val_dataset = None
+ self.test_dataset = None
- self._name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}"
+ # Stochastic Weight Averaging placeholders.
+ self.swa_args = swa_args
+ self._swa_start = None
+ self._swa_scheduler = None
+ self._swa_network = None
- if metrics is not None:
- self._metrics = metrics
+ # Experiment directory.
+ self.model_dir = None
+
+ # Flag for configured model.
+ self.is_configured = False
+ self.data_prepared = False
+
+ # Flag for stopping training.
+ self.stop_training = False
+
+ self._name = (
+ f"{self.__class__.__name__}_{dataset.__name__}_{network_fn.__name__}"
+ )
+
+ self._metrics = metrics if metrics is not None else None
# Set the device.
- if device is None:
- self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- else:
- self._device = device
+ self._device = (
+ torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ if device is None
+ else device
+ )
# Configure network.
- self._network, self._network_args = self._configure_network(
- network_fn, network_args
- )
+ self._network = None
+ self._network_args = network_args
+ self._configure_network(network_fn)
- # To device.
- self._network.to(self._device)
+ # Place network on device (GPU).
+ self.to_device()
+
+ # Loss and Optimizer placeholders for before loading.
+ self._criterion = criterion
+ self.criterion_args = criterion_args
+
+ self._optimizer = optimizer
+ self.optimizer_args = optimizer_args
+
+ self._lr_scheduler = lr_scheduler
+ self.lr_scheduler_args = lr_scheduler_args
+
+ def configure_model(self) -> None:
+ """Configures criterion and optimizers."""
+ if not self.is_configured:
+ self._configure_criterion()
+ self._configure_optimizers()
+
+ # Prints a summary of the network in terminal.
+ self.summary()
+
+ # Set this flag to true to prevent the model from configuring again.
+ self.is_configured = True
+
+ def prepare_data(self) -> None:
+ """Prepare data for training."""
+ # TODO add downloading.
+ if not self.data_prepared:
+ # Load train dataset.
+ train_dataset = self.dataset(train=True, **self.dataset_args["args"])
+
+ # Set input shape.
+ self._input_shape = train_dataset.input_shape
+
+ # Split train dataset into a training and validation partition.
+ dataset_len = len(train_dataset)
+ train_len = int(
+ self.dataset_args["train_args"]["train_fraction"] * dataset_len
+ )
+ val_len = dataset_len - train_len
+ self.train_dataset, self.val_dataset = random_split(
+ train_dataset, lengths=[train_len, val_len]
+ )
+
+ # Load test dataset.
+ self.test_dataset = self.dataset(train=False, **self.dataset_args["args"])
+
+ # Set the flag to true to disable ability to load data agian.
+ self.data_prepared = True
- # Configure training objects.
- self._criterion = self._configure_criterion(criterion, criterion_args)
- self._optimizer, self._lr_scheduler = self._configure_optimizers(
- optimizer, optimizer_args, lr_scheduler, lr_scheduler_args
+ def train_dataloader(self) -> DataLoader:
+ """Returns data loader for training set."""
+ return DataLoader(
+ self.train_dataset,
+ batch_size=self.dataset_args["train_args"]["batch_size"],
+ num_workers=self.dataset_args["train_args"]["num_workers"],
+ shuffle=True,
+ pin_memory=True,
)
- # Experiment directory.
- self.model_dir = None
+ def val_dataloader(self) -> DataLoader:
+ """Returns data loader for validation set."""
+ return DataLoader(
+ self.val_dataset,
+ batch_size=self.dataset_args["train_args"]["batch_size"],
+ num_workers=self.dataset_args["train_args"]["num_workers"],
+ shuffle=True,
+ pin_memory=True,
+ )
- # Flag for stopping training.
- self.stop_training = False
+ def test_dataloader(self) -> DataLoader:
+ """Returns data loader for test set."""
+ return DataLoader(
+ self.test_dataset,
+ batch_size=self.dataset_args["train_args"]["batch_size"],
+ num_workers=self.dataset_args["train_args"]["num_workers"],
+ shuffle=False,
+ pin_memory=True,
+ )
- def _configure_data_loader(
- self, data_loader_args: Optional[Dict]
- ) -> Tuple[str, Dict, EmnistMapper]:
- """Loads data loader, dataset name, and dataset mapper."""
- if data_loader_args is not None:
- data_loaders = fetch_data_loaders(**data_loader_args)
- dataset = list(data_loaders.values())[0].dataset
- dataset_name = dataset.__name__
- mapper = dataset.mapper
- else:
- self._mapper = EmnistMapper()
- dataset_name = "*"
- data_loaders = None
- return dataset_name, data_loaders, mapper
-
- def _configure_network(
- self, network_fn: Type[nn.Module], network_args: Optional[Dict]
- ) -> Tuple[Type[nn.Module], Dict]:
+ def _configure_network(self, network_fn: Type[nn.Module]) -> None:
"""Loads the network."""
# If no network arguemnts are given, load pretrained weights if they exist.
- if network_args is None:
- network, network_args = self.load_weights(network_fn)
+ if self._network_args is None:
+ self.load_weights(network_fn)
else:
- network = network_fn(**network_args)
- return network, network_args
+ self._network = network_fn(**self._network_args)
- def _configure_criterion(
- self, criterion: Optional[Callable], criterion_args: Optional[Dict]
- ) -> Optional[Callable]:
+ def _configure_criterion(self) -> None:
"""Loads the criterion."""
- if criterion is not None:
- _criterion = criterion(**criterion_args)
- else:
- _criterion = None
- return _criterion
+ self._criterion = (
+ self._criterion(**self.criterion_args)
+ if self._criterion is not None
+ else None
+ )
- def _configure_optimizers(
- self,
- optimizer: Optional[Callable],
- optimizer_args: Optional[Dict],
- lr_scheduler: Optional[Callable],
- lr_scheduler_args: Optional[Dict],
- ) -> Tuple[Optional[Callable], Optional[Callable]]:
+ def _configure_optimizers(self,) -> None:
"""Loads the optimizers."""
- if optimizer is not None:
- _optimizer = optimizer(self._network.parameters(), **optimizer_args)
+ if self._optimizer is not None:
+ self._optimizer = self._optimizer(
+ self._network.parameters(), **self.optimizer_args
+ )
else:
- _optimizer = None
+ self._optimizer = None
- if _optimizer and lr_scheduler is not None:
- if "OneCycleLR" in str(lr_scheduler):
- lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders["train"])
- _lr_scheduler = lr_scheduler(_optimizer, **lr_scheduler_args)
+ if self._optimizer and self._lr_scheduler is not None:
+ if "OneCycleLR" in str(self._lr_scheduler):
+ self.lr_scheduler_args["steps_per_epoch"] = len(self.train_dataloader())
+ self._lr_scheduler = self._lr_scheduler(
+ self._optimizer, **self.lr_scheduler_args
+ )
else:
- _lr_scheduler = None
+ self._lr_scheduler = None
- return _optimizer, _lr_scheduler
+ if self.swa_args is not None:
+ self._swa_start = self.swa_args["start"]
+ self._swa_scheduler = SWALR(self._optimizer, swa_lr=self.swa_args["lr"])
+ self._swa_network = AveragedModel(self._network).to(self.device)
@property
- def __name__(self) -> str:
+ def name(self) -> str:
"""Returns the name of the model."""
return self._name
@@ -159,7 +240,7 @@ class Model(ABC):
return self._input_shape
@property
- def mapper(self) -> Dict:
+ def mapper(self) -> EmnistMapper:
"""Returns the mapper that maps between ints and chars."""
return self._mapper
@@ -202,13 +283,24 @@ class Model(ABC):
return self._lr_scheduler
@property
- def data_loaders(self) -> Optional[Dict]:
- """Dataloaders."""
- return self._data_loaders
+ def swa_scheduler(self) -> Optional[Callable]:
+ """Returns the stochastic weight averaging scheduler."""
+ return self._swa_scheduler
+
+ @property
+ def swa_start(self) -> Optional[Callable]:
+ """Returns the start epoch of stochastic weight averaging."""
+ return self._swa_start
@property
- def network(self) -> nn.Module:
+ def swa_network(self) -> Optional[Callable]:
+ """Returns the stochastic weight averaging network."""
+ return self._swa_network
+
+ @property
+ def network(self) -> Type[nn.Module]:
"""Neural network."""
+ # Returns the SWA network if available.
return self._network
@property
@@ -217,15 +309,27 @@ class Model(ABC):
WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True)
return str(WEIGHT_DIRNAME / f"{self._name}_weights.pt")
- def summary(self) -> None:
+ def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor:
+ """Compute the loss."""
+ return self.criterion(output, targets)
+
+ def summary(
+ self, input_shape: Optional[Tuple[int, int, int]] = None, depth: int = 5
+ ) -> None:
"""Prints a summary of the network architecture."""
- device = re.sub("[^A-Za-z]+", "", self.device)
- if self._input_shape is not None:
+
+ if input_shape is not None:
+ summary(self._network, input_shape, depth=depth, device=self.device)
+ elif self._input_shape is not None:
input_shape = (1,) + tuple(self._input_shape)
- summary(self._network, input_shape, device=device)
+ summary(self._network, input_shape, depth=depth, device=self.device)
else:
logger.warning("Could not print summary as input shape is not set.")
+ def to_device(self) -> None:
+ """Places the network on the device (GPU)."""
+ self._network.to(self._device)
+
def _get_state_dict(self) -> Dict:
"""Get the state dict of the model."""
state = {"model_state": self._network.state_dict()}
@@ -236,69 +340,67 @@ class Model(ABC):
if self._lr_scheduler is not None:
state["scheduler_state"] = self._lr_scheduler.state_dict()
+ if self._swa_network is not None:
+ state["swa_network"] = self._swa_network.state_dict()
+
return state
- def load_checkpoint(self, path: Path) -> int:
+ def load_from_checkpoint(self, checkpoint_path: Path) -> None:
"""Load a previously saved checkpoint.
Args:
- path (Path): Path to the experiment with the checkpoint.
-
- Returns:
- epoch (int): The last epoch when the checkpoint was created.
+ checkpoint_path (Path): Path to the experiment with the checkpoint.
"""
logger.debug("Loading checkpoint...")
- if not path.exists():
- logger.debug("File does not exist {str(path)}")
+ if not checkpoint_path.exists():
+ logger.debug("File does not exist {str(checkpoint_path)}")
- checkpoint = torch.load(str(path))
+ checkpoint = torch.load(str(checkpoint_path))
self._network.load_state_dict(checkpoint["model_state"])
if self._optimizer is not None:
self._optimizer.load_state_dict(checkpoint["optimizer_state"])
- # Does not work when loadning from previous checkpoint and trying to train beyond the last max epochs.
- # if self._lr_scheduler is not None:
- # self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"])
-
- epoch = checkpoint["epoch"]
+ if self._lr_scheduler is not None:
+ # Does not work when loadning from previous checkpoint and trying to train beyond the last max epochs
+ # with OneCycleLR.
+ if self._lr_scheduler.__class__.__name__ != "OneCycleLR":
+ self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"])
- return epoch
+ if self._swa_network is not None:
+ self._swa_network.load_state_dict(checkpoint["swa_network"])
- def save_checkpoint(self, is_best: bool, epoch: int, val_metric: str) -> None:
+ def save_checkpoint(
+ self, checkpoint_path: Path, is_best: bool, epoch: int, val_metric: str
+ ) -> None:
"""Saves a checkpoint of the model.
Args:
+ checkpoint_path (Path): Path to the experiment with the checkpoint.
is_best (bool): If it is the currently best model.
epoch (int): The epoch of the checkpoint.
val_metric (str): Validation metric.
- Raises:
- ValueError: If the self.model_dir is not set.
-
"""
state = self._get_state_dict()
state["is_best"] = is_best
state["epoch"] = epoch
state["network_args"] = self._network_args
- if self.model_dir is None:
- raise ValueError("Experiment directory is not set.")
-
- self.model_dir.mkdir(parents=True, exist_ok=True)
+ checkpoint_path.mkdir(parents=True, exist_ok=True)
logger.debug("Saving checkpoint...")
- filepath = str(self.model_dir / "last.pt")
+ filepath = str(checkpoint_path / "last.pt")
torch.save(state, filepath)
if is_best:
logger.debug(
f"Found a new best {val_metric}. Saving best checkpoint and weights."
)
- shutil.copyfile(filepath, str(self.model_dir / "best.pt"))
+ shutil.copyfile(filepath, str(checkpoint_path / "best.pt"))
- def load_weights(self, network_fn: Type[nn.Module]) -> Tuple[Type[nn.Module], Dict]:
+ def load_weights(self, network_fn: Type[nn.Module]) -> None:
"""Load the network weights."""
logger.debug("Loading network with pretrained weights.")
filename = glob(self.weights_filename)[0]
@@ -308,13 +410,16 @@ class Model(ABC):
)
# Loading state directory.
state_dict = torch.load(filename, map_location=torch.device(self._device))
- network_args = state_dict["network_args"]
+ self._network_args = state_dict["network_args"]
weights = state_dict["model_state"]
# Initializes the network with trained weights.
- network = network_fn(**self._network_args)
- network.load_state_dict(weights)
- return network, network_args
+ self._network = network_fn(**self._network_args)
+ self._network.load_state_dict(weights)
+
+ if "swa_network" in state_dict:
+ self._swa_network = AveragedModel(self._network).to(self.device)
+ self._swa_network.load_state_dict(state_dict["swa_network"])
def save_weights(self, path: Path) -> None:
"""Save the network weights."""
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
index 0fd7afd..64ba693 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/src/text_recognizer/models/character_model.py
@@ -4,8 +4,10 @@ from typing import Callable, Dict, Optional, Tuple, Type, Union
import numpy as np
import torch
from torch import nn
+from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
+from text_recognizer.datasets import EmnistMapper
from text_recognizer.models.base import Model
@@ -15,8 +17,9 @@ class CharacterModel(Model):
def __init__(
self,
network_fn: Type[nn.Module],
+ dataset: Type[Dataset],
network_args: Optional[Dict] = None,
- data_loader_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
metrics: Optional[Dict] = None,
criterion: Optional[Callable] = None,
criterion_args: Optional[Dict] = None,
@@ -24,14 +27,16 @@ class CharacterModel(Model):
optimizer_args: Optional[Dict] = None,
lr_scheduler: Optional[Callable] = None,
lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
device: Optional[str] = None,
) -> None:
"""Initializes the CharacterModel."""
super().__init__(
network_fn,
+ dataset,
network_args,
- data_loader_args,
+ dataset_args,
metrics,
criterion,
criterion_args,
@@ -39,8 +44,11 @@ class CharacterModel(Model):
optimizer_args,
lr_scheduler,
lr_scheduler_args,
+ swa_args,
device,
)
+ if self._mapper is None:
+ self._mapper = EmnistMapper()
self.tensor_transform = ToTensor()
self.softmax = nn.Softmax(dim=0)
@@ -67,9 +75,13 @@ class CharacterModel(Model):
# Put the image tensor on the device the model weights are on.
image = image.to(self.device)
- logits = self.network(image)
+ logits = (
+ self.swa_network(image)
+ if self.swa_network is not None
+ else self.network(image)
+ )
- prediction = self.softmax(logits.data.squeeze())
+ prediction = self.softmax(logits.squeeze(0))
index = int(torch.argmax(prediction, dim=0))
confidence_of_prediction = prediction[index]
diff --git a/src/text_recognizer/models/line_ctc_model.py b/src/text_recognizer/models/line_ctc_model.py
new file mode 100644
index 0000000..97308a7
--- /dev/null
+++ b/src/text_recognizer/models/line_ctc_model.py
@@ -0,0 +1,105 @@
+"""Defines the LineCTCModel class."""
+from typing import Callable, Dict, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets import EmnistMapper
+from text_recognizer.models.base import Model
+from text_recognizer.networks import greedy_decoder
+
+
+class LineCTCModel(Model):
+ """Model for predicting a sequence of characters from an image of a text line."""
+
+ def __init__(
+ self,
+ network_fn: Type[nn.Module],
+ dataset: Type[Dataset],
+ network_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ super().__init__(
+ network_fn,
+ dataset,
+ network_args,
+ dataset_args,
+ metrics,
+ criterion,
+ criterion_args,
+ optimizer,
+ optimizer_args,
+ lr_scheduler,
+ lr_scheduler_args,
+ swa_args,
+ device,
+ )
+ if self._mapper is None:
+ self._mapper = EmnistMapper()
+ self.tensor_transform = ToTensor()
+
+ def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor:
+ """Computes the CTC loss.
+
+ Args:
+ output (Tensor): Model predictions.
+ targets (Tensor): Correct output sequence.
+
+ Returns:
+ Tensor: The CTC loss.
+
+ """
+ input_lengths = torch.full(
+ size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long,
+ )
+ target_lengths = torch.full(
+ size=(output.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
+ )
+ return self.criterion(output, targets, input_lengths, target_lengths)
+
+ @torch.no_grad()
+ def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
+ """Predict on a single input."""
+ if image.dtype == np.uint8:
+ # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
+ image = self.tensor_transform(image)
+
+ # Rescale image between 0 and 1.
+ if image.dtype == torch.uint8:
+ # If the image is an unscaled tensor.
+ image = image.type("torch.FloatTensor") / 255
+
+ # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device)
+ log_probs = (
+ self.swa_network(image)
+ if self.swa_network is not None
+ else self.network(image)
+ )
+
+ raw_pred, _ = greedy_decoder(
+ predictions=log_probs,
+ character_mapper=self.mapper,
+ blank_label=79,
+ collapse_repeated=True,
+ )
+
+ log_probs, _ = log_probs.max(dim=2)
+
+ predicted_characters = "".join(raw_pred[0])
+ confidence_of_prediction = torch.exp(log_probs.sum()).item()
+
+ return predicted_characters, confidence_of_prediction
diff --git a/src/text_recognizer/models/metrics.py b/src/text_recognizer/models/metrics.py
index ac8d68e..6a26216 100644
--- a/src/text_recognizer/models/metrics.py
+++ b/src/text_recognizer/models/metrics.py
@@ -1,19 +1,89 @@
"""Utility functions for models."""
-
+import Levenshtein as Lev
import torch
+from torch import Tensor
+
+from text_recognizer.networks import greedy_decoder
-def accuracy(outputs: torch.Tensor, labels: torch.Tensor) -> float:
+def accuracy(outputs: Tensor, labels: Tensor) -> float:
"""Computes the accuracy.
Args:
- outputs (torch.Tensor): The output from the network.
- labels (torch.Tensor): Ground truth labels.
+ outputs (Tensor): The output from the network.
+ labels (Tensor): Ground truth labels.
Returns:
float: The accuracy for the batch.
"""
_, predicted = torch.max(outputs.data, dim=1)
- acc = (predicted == labels).sum().item() / labels.shape[0]
+ acc = (predicted == labels).sum().float() / labels.shape[0]
+ acc = acc.item()
return acc
+
+
+def cer(outputs: Tensor, targets: Tensor) -> float:
+ """Computes the character error rate.
+
+ Args:
+ outputs (Tensor): The output from the network.
+ targets (Tensor): Ground truth labels.
+
+ Returns:
+ float: The cer for the batch.
+
+ """
+ target_lengths = torch.full(
+ size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
+ )
+ decoded_predictions, decoded_targets = greedy_decoder(
+ outputs, targets, target_lengths
+ )
+
+ lev_dist = 0
+
+ for prediction, target in zip(decoded_predictions, decoded_targets):
+ prediction = "".join(prediction)
+ target = "".join(target)
+ prediction, target = (
+ prediction.replace(" ", ""),
+ target.replace(" ", ""),
+ )
+ lev_dist += Lev.distance(prediction, target)
+ return lev_dist / len(decoded_predictions)
+
+
+def wer(outputs: Tensor, targets: Tensor) -> float:
+ """Computes the Word error rate.
+
+ Args:
+ outputs (Tensor): The output from the network.
+ targets (Tensor): Ground truth labels.
+
+ Returns:
+ float: The wer for the batch.
+
+ """
+ target_lengths = torch.full(
+ size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
+ )
+ decoded_predictions, decoded_targets = greedy_decoder(
+ outputs, targets, target_lengths
+ )
+
+ lev_dist = 0
+
+ for prediction, target in zip(decoded_predictions, decoded_targets):
+ prediction = "".join(prediction)
+ target = "".join(target)
+
+ b = set(prediction.split() + target.split())
+ word2char = dict(zip(b, range(len(b))))
+
+ w1 = [chr(word2char[w]) for w in prediction.split()]
+ w2 = [chr(word2char[w]) for w in target.split()]
+
+ lev_dist += Lev.distance("".join(w1), "".join(w2))
+
+ return lev_dist / len(decoded_predictions)
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
index a83ca35..d20c86a 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/src/text_recognizer/networks/__init__.py
@@ -1,6 +1,19 @@
"""Network modules."""
+from .ctc import greedy_decoder
from .lenet import LeNet
+from .line_lstm_ctc import LineRecurrentNetwork
+from .misc import sliding_window
from .mlp import MLP
-from .residual_network import ResidualNetwork
+from .residual_network import ResidualNetwork, ResidualNetworkEncoder
+from .wide_resnet import WideResidualNetwork
-__all__ = ["MLP", "LeNet", "ResidualNetwork"]
+__all__ = [
+ "greedy_decoder",
+ "MLP",
+ "LeNet",
+ "LineRecurrentNetwork",
+ "ResidualNetwork",
+ "ResidualNetworkEncoder",
+ "sliding_window",
+ "WideResidualNetwork",
+]
diff --git a/src/text_recognizer/networks/ctc.py b/src/text_recognizer/networks/ctc.py
index 00ad47e..fc0d21d 100644
--- a/src/text_recognizer/networks/ctc.py
+++ b/src/text_recognizer/networks/ctc.py
@@ -1,10 +1,58 @@
"""Decodes the CTC output."""
-#
-# from typing import Tuple
-# import torch
-#
-#
-# def greedy_decoder(
-# output, labels, label_length, blank_label, collapse_repeated=True
-# ) -> Tuple[torch.Tensor, torch.Tensor]:
-# pass
+from typing import Callable, List, Optional, Tuple
+
+from einops import rearrange
+import torch
+from torch import Tensor
+
+from text_recognizer.datasets import EmnistMapper
+
+
+def greedy_decoder(
+ predictions: Tensor,
+ targets: Optional[Tensor] = None,
+ target_lengths: Optional[Tensor] = None,
+ character_mapper: Optional[Callable] = None,
+ blank_label: int = 79,
+ collapse_repeated: bool = True,
+) -> Tuple[List[str], List[str]]:
+ """Greedy CTC decoder.
+
+ Args:
+ predictions (Tensor): Tenor of network predictions, shape [time, batch, classes].
+ targets (Optional[Tensor]): Target tensor, shape is [batch, targets]. Defaults to None.
+ target_lengths (Optional[Tensor]): Length of each target tensor. Defaults to None.
+ character_mapper (Optional[Callable]): A emnist/character mapper for mapping integers to characters. Defaults
+ to None.
+ blank_label (int): The blank character to be ignored. Defaults to 79.
+ collapse_repeated (bool): Collapase consecutive predictions of the same character. Defaults to True.
+
+ Returns:
+ Tuple[List[str], List[str]]: Tuple of decoded predictions and decoded targets.
+
+ """
+
+ if character_mapper is None:
+ character_mapper = EmnistMapper()
+
+ predictions = rearrange(torch.argmax(predictions, dim=2), "t b -> b t")
+ decoded_predictions = []
+ decoded_targets = []
+ for i, prediction in enumerate(predictions):
+ decoded_prediction = []
+ decoded_target = []
+ if targets is not None and target_lengths is not None:
+ for target_index in targets[i][: target_lengths[i]]:
+ if target_index == blank_label:
+ continue
+ decoded_target.append(character_mapper(int(target_index)))
+ decoded_targets.append(decoded_target)
+ for j, index in enumerate(prediction):
+ if index != blank_label:
+ if collapse_repeated and j != 0 and index == prediction[j - 1]:
+ continue
+ decoded_prediction.append(index.item())
+ decoded_predictions.append(
+ [character_mapper(int(pred_index)) for pred_index in decoded_prediction]
+ )
+ return decoded_predictions, decoded_targets
diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py
index 91d3f2c..53c575e 100644
--- a/src/text_recognizer/networks/lenet.py
+++ b/src/text_recognizer/networks/lenet.py
@@ -1,4 +1,4 @@
-"""Defines the LeNet network."""
+"""Implementation of the LeNet network."""
from typing import Callable, Dict, Optional, Tuple
from einops.layers.torch import Rearrange
@@ -9,7 +9,7 @@ from text_recognizer.networks.misc import activation_function
class LeNet(nn.Module):
- """LeNet network."""
+ """LeNet network for character prediction."""
def __init__(
self,
@@ -17,10 +17,10 @@ class LeNet(nn.Module):
kernel_sizes: Tuple[int, ...] = (3, 3, 2),
hidden_size: Tuple[int, ...] = (9216, 128),
dropout_rate: float = 0.2,
- output_size: int = 10,
+ num_classes: int = 10,
activation_fn: Optional[str] = "relu",
) -> None:
- """The LeNet network.
+ """Initialization of the LeNet network.
Args:
channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64).
@@ -28,7 +28,7 @@ class LeNet(nn.Module):
hidden_size (Tuple[int, ...]): Size of the flattend output form the convolutional layers.
Defaults to (9216, 128).
dropout_rate (float): The dropout rate. Defaults to 0.2.
- output_size (int): Number of classes. Defaults to 10.
+ num_classes (int): Number of classes. Defaults to 10.
activation_fn (Optional[str]): The name of non-linear activation function. Defaults to relu.
"""
@@ -55,7 +55,7 @@ class LeNet(nn.Module):
nn.Linear(in_features=hidden_size[0], out_features=hidden_size[1]),
activation_fn,
nn.Dropout(p=dropout_rate),
- nn.Linear(in_features=hidden_size[1], out_features=output_size),
+ nn.Linear(in_features=hidden_size[1], out_features=num_classes),
]
self.layers = nn.Sequential(*self.layers)
diff --git a/src/text_recognizer/networks/line_lstm_ctc.py b/src/text_recognizer/networks/line_lstm_ctc.py
index 2e2c3a5..988b615 100644
--- a/src/text_recognizer/networks/line_lstm_ctc.py
+++ b/src/text_recognizer/networks/line_lstm_ctc.py
@@ -1,5 +1,81 @@
"""LSTM with CTC for handwritten text recognition within a line."""
+import importlib
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+from einops import rearrange, reduce
+from einops.layers.torch import Rearrange, Reduce
import torch
from torch import nn
from torch import Tensor
+
+
+class LineRecurrentNetwork(nn.Module):
+ """Network that takes a image of a text line and predicts tokens that are in the image."""
+
+ def __init__(
+ self,
+ encoder: str,
+ encoder_args: Dict = None,
+ flatten: bool = True,
+ input_size: int = 128,
+ hidden_size: int = 128,
+ num_layers: int = 1,
+ num_classes: int = 80,
+ patch_size: Tuple[int, int] = (28, 28),
+ stride: Tuple[int, int] = (1, 14),
+ ) -> None:
+ super().__init__()
+ self.encoder_args = encoder_args or {}
+ self.patch_size = patch_size
+ self.stride = stride
+ self.sliding_window = self._configure_sliding_window()
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.encoder = self._configure_encoder(encoder)
+ self.flatten = flatten
+ self.rnn = nn.LSTM(
+ input_size=self.input_size,
+ hidden_size=self.hidden_size,
+ num_layers=num_layers,
+ )
+ self.decoder = nn.Sequential(
+ nn.Linear(in_features=self.hidden_size, out_features=num_classes),
+ nn.LogSoftmax(dim=2),
+ )
+
+ def _configure_encoder(self, encoder: str) -> Type[nn.Module]:
+ network_module = importlib.import_module("text_recognizer.networks")
+ encoder_ = getattr(network_module, encoder)
+ return encoder_(**self.encoder_args)
+
+ def _configure_sliding_window(self) -> nn.Sequential:
+ return nn.Sequential(
+ nn.Unfold(kernel_size=self.patch_size, stride=self.stride),
+ Rearrange(
+ "b (c h w) t -> b t c h w",
+ h=self.patch_size[0],
+ w=self.patch_size[1],
+ c=1,
+ ),
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM."""
+ if len(x.shape) == 3:
+ x = x.unsqueeze(0)
+ x = self.sliding_window(x)
+
+ # Rearrange from a sequence of patches for feedforward network.
+ b, t = x.shape[:2]
+ x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
+ x = self.encoder(x)
+
+ # Avgerage pooling.
+ x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t) if self.flatten else x
+
+ # Sequence predictions.
+ x, _ = self.rnn(x)
+
+ # Sequence to classifcation layer.
+ x = self.decoder(x)
+ return x
diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/misc.py
index 6f61b5d..cac9e78 100644
--- a/src/text_recognizer/networks/misc.py
+++ b/src/text_recognizer/networks/misc.py
@@ -22,9 +22,10 @@ def sliding_window(
"""
unfold = nn.Unfold(kernel_size=patch_size, stride=stride)
# Preform the slidning window, unsqueeze as the channel dimesion is lost.
- patches = unfold(images).unsqueeze(1)
+ c = images.shape[1]
+ patches = unfold(images)
patches = rearrange(
- patches, "b c (h w) t -> b t c h w", h=patch_size[0], w=patch_size[1]
+ patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1]
)
return patches
diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py
index acebdaa..d66af28 100644
--- a/src/text_recognizer/networks/mlp.py
+++ b/src/text_recognizer/networks/mlp.py
@@ -14,7 +14,7 @@ class MLP(nn.Module):
def __init__(
self,
input_size: int = 784,
- output_size: int = 10,
+ num_classes: int = 10,
hidden_size: Union[int, List] = 128,
num_layers: int = 3,
dropout_rate: float = 0.2,
@@ -24,7 +24,7 @@ class MLP(nn.Module):
Args:
input_size (int): The input shape of the network. Defaults to 784.
- output_size (int): Number of classes in the dataset. Defaults to 10.
+ num_classes (int): Number of classes in the dataset. Defaults to 10.
hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128.
num_layers (int): The number of hidden layers. Defaults to 3.
dropout_rate (float): The dropout rate at each layer. Defaults to 0.2.
@@ -55,7 +55,7 @@ class MLP(nn.Module):
self.layers.append(nn.Dropout(p=dropout_rate))
self.layers.append(
- nn.Linear(in_features=hidden_size[-1], out_features=output_size)
+ nn.Linear(in_features=hidden_size[-1], out_features=num_classes)
)
self.layers = nn.Sequential(*self.layers)
diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py
index 47e351a..1b5d6b3 100644
--- a/src/text_recognizer/networks/residual_network.py
+++ b/src/text_recognizer/networks/residual_network.py
@@ -8,6 +8,7 @@ from torch import nn
from torch import Tensor
from text_recognizer.networks.misc import activation_function
+from text_recognizer.networks.stn import SpatialTransformerNetwork
class Conv2dAuto(nn.Conv2d):
@@ -197,25 +198,28 @@ class ResidualLayer(nn.Module):
return x
-class Encoder(nn.Module):
+class ResidualNetworkEncoder(nn.Module):
"""Encoder network."""
def __init__(
self,
in_channels: int = 1,
- block_sizes: List[int] = (32, 64),
- depths: List[int] = (2, 2),
+ block_sizes: Union[int, List[int]] = (32, 64),
+ depths: Union[int, List[int]] = (2, 2),
activation: str = "relu",
block: Type[nn.Module] = BasicBlock,
+ levels: int = 1,
+ stn: bool = False,
*args,
**kwargs
) -> None:
super().__init__()
-
- self.block_sizes = block_sizes
- self.depths = depths
+ self.stn = SpatialTransformerNetwork() if stn else None
+ self.block_sizes = (
+ block_sizes if isinstance(block_sizes, list) else [block_sizes] * levels
+ )
+ self.depths = depths if isinstance(depths, list) else [depths] * levels
self.activation = activation
-
self.gate = nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
@@ -227,7 +231,7 @@ class Encoder(nn.Module):
),
nn.BatchNorm2d(self.block_sizes[0]),
activation_function(self.activation),
- nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
+ nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
)
self.blocks = self._configure_blocks(block)
@@ -271,11 +275,13 @@ class Encoder(nn.Module):
# If batch dimenstion is missing, it needs to be added.
if len(x.shape) == 3:
x = x.unsqueeze(0)
+ if self.stn is not None:
+ x = self.stn(x)
x = self.gate(x)
return self.blocks(x)
-class Decoder(nn.Module):
+class ResidualNetworkDecoder(nn.Module):
"""Classification head."""
def __init__(self, in_features: int, num_classes: int = 80) -> None:
@@ -295,19 +301,12 @@ class ResidualNetwork(nn.Module):
def __init__(self, in_channels: int, num_classes: int, *args, **kwargs) -> None:
super().__init__()
- self.encoder = Encoder(in_channels, *args, **kwargs)
- self.decoder = Decoder(
+ self.encoder = ResidualNetworkEncoder(in_channels, *args, **kwargs)
+ self.decoder = ResidualNetworkDecoder(
in_features=self.encoder.blocks[-1].blocks[-1].expanded_channels,
num_classes=num_classes,
)
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
- elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
- nn.init.constant_(m.weight, 1)
- nn.init.constant_(m.bias, 0)
-
def forward(self, x: Tensor) -> Tensor:
"""Forward pass."""
x = self.encoder(x)
diff --git a/src/text_recognizer/networks/stn.py b/src/text_recognizer/networks/stn.py
new file mode 100644
index 0000000..b031128
--- /dev/null
+++ b/src/text_recognizer/networks/stn.py
@@ -0,0 +1,44 @@
+"""Spatial Transformer Network."""
+
+from einops.layers.torch import Rearrange
+import torch
+from torch import nn
+from torch import Tensor
+import torch.nn.functional as F
+
+
+class SpatialTransformerNetwork(nn.Module):
+ """A network with differentiable attention.
+
+ Network that learns how to perform spatial transformations on the input image in order to enhance the
+ geometric invariance of the model.
+
+ # TODO: add arguements to make it more general.
+
+ """
+
+ def __init__(self) -> None:
+ super().__init__()
+ # Initialize the identity transformation and its weights and biases.
+ linear = nn.Linear(32, 3 * 2)
+ linear.weight.data.zero_()
+ linear.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
+
+ self.theta = nn.Sequential(
+ nn.Conv2d(in_channels=1, out_channels=8, kernel_size=7),
+ nn.MaxPool2d(kernel_size=2, stride=2),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(in_channels=8, out_channels=10, kernel_size=5),
+ nn.MaxPool2d(kernel_size=2, stride=2),
+ nn.ReLU(inplace=True),
+ Rearrange("b c h w -> b (c h w)", h=3, w=3),
+ nn.Linear(in_features=10 * 3 * 3, out_features=32),
+ nn.ReLU(inplace=True),
+ linear,
+ Rearrange("b (row col) -> b row col", row=2, col=3),
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """The spatial transformation."""
+ grid = F.affine_grid(self.theta(x), x.shape)
+ return F.grid_sample(x, grid, align_corners=False)
diff --git a/src/text_recognizer/networks/wide_resnet.py b/src/text_recognizer/networks/wide_resnet.py
new file mode 100644
index 0000000..d1c8f9a
--- /dev/null
+++ b/src/text_recognizer/networks/wide_resnet.py
@@ -0,0 +1,214 @@
+"""Wide Residual CNN."""
+from functools import partial
+from typing import Callable, Dict, List, Optional, Type, Union
+
+from einops.layers.torch import Rearrange, Reduce
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.misc import activation_function
+
+
+def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
+ """Helper function for a 3x3 2d convolution."""
+ return nn.Conv2d(
+ in_channels=in_planes,
+ out_channels=out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias=False,
+ )
+
+
+def conv_init(module: Type[nn.Module]) -> None:
+ """Initializes the weights for convolution and batchnorms."""
+ classname = module.__class__.__name__
+ if classname.find("Conv") != -1:
+ nn.init.xavier_uniform_(module.weight, gain=np.sqrt(2))
+ nn.init.constant(module.bias, 0)
+ elif classname.find("BatchNorm") != -1:
+ nn.init.constant(module.weight, 1)
+ nn.init.constant(module.bias, 0)
+
+
+class WideBlock(nn.Module):
+ """Block used in WideResNet."""
+
+ def __init__(
+ self,
+ in_planes: int,
+ out_planes: int,
+ dropout_rate: float,
+ stride: int = 1,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ self.in_planes = in_planes
+ self.out_planes = out_planes
+ self.dropout_rate = dropout_rate
+ self.stride = stride
+ self.activation = activation_function(activation)
+
+ # Build blocks.
+ self.blocks = nn.Sequential(
+ nn.BatchNorm2d(self.in_planes),
+ self.activation,
+ conv3x3(in_planes=self.in_planes, out_planes=self.out_planes),
+ nn.Dropout(p=self.dropout_rate),
+ nn.BatchNorm2d(self.out_planes),
+ self.activation,
+ conv3x3(
+ in_planes=self.out_planes,
+ out_planes=self.out_planes,
+ stride=self.stride,
+ ),
+ )
+
+ self.shortcut = (
+ nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.in_planes,
+ out_channels=self.out_planes,
+ kernel_size=1,
+ stride=self.stride,
+ bias=False,
+ ),
+ )
+ if self._apply_shortcut
+ else None
+ )
+
+ @property
+ def _apply_shortcut(self) -> bool:
+ """If shortcut should be applied or not."""
+ return self.stride != 1 or self.in_planes != self.out_planes
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ residual = x
+ if self._apply_shortcut:
+ residual = self.shortcut(x)
+ x = self.blocks(x)
+ x += residual
+ return x
+
+
+class WideResidualNetwork(nn.Module):
+ """WideResNet for character predictions.
+
+ Can be used for classification or encoding of images to a latent vector.
+
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 1,
+ in_planes: int = 16,
+ num_classes: int = 80,
+ depth: int = 16,
+ width_factor: int = 10,
+ dropout_rate: float = 0.0,
+ num_layers: int = 3,
+ block: Type[nn.Module] = WideBlock,
+ activation: str = "relu",
+ use_decoder: bool = True,
+ ) -> None:
+ """The initialization of the WideResNet.
+
+ Args:
+ in_channels (int): Number of input channels. Defaults to 1.
+ in_planes (int): Number of channels to use in the first output kernel. Defaults to 16.
+ num_classes (int): Number of classes. Defaults to 80.
+ depth (int): Set the number of blocks to use. Defaults to 16.
+ width_factor (int): Factor for scaling the number of channels in the network. Defaults to 10.
+ dropout_rate (float): The dropout rate. Defaults to 0.0.
+ num_layers (int): Number of layers of blocks. Defaults to 3.
+ block (Type[nn.Module]): The default block is WideBlock. Defaults to WideBlock.
+ activation (str): Name of the activation to use. Defaults to "relu".
+ use_decoder (bool): If True, the network output character predictions, if False, the network outputs a
+ latent vector. Defaults to True.
+
+ Raises:
+ RuntimeError: If the depth is not of the size `6n+4`.
+
+ """
+
+ super().__init__()
+ if (depth - 4) % 6 != 0:
+ raise RuntimeError("Wide-resnet depth should be 6n+4")
+ self.in_channels = in_channels
+ self.in_planes = in_planes
+ self.num_classes = num_classes
+ self.num_blocks = (depth - 4) // 6
+ self.width_factor = width_factor
+ self.num_layers = num_layers
+ self.block = block
+ self.dropout_rate = dropout_rate
+ self.activation = activation_function(activation)
+
+ self.num_stages = [self.in_planes] + [
+ self.in_planes * 2 ** n * self.width_factor for n in range(self.num_layers)
+ ]
+ self.num_stages = list(zip(self.num_stages, self.num_stages[1:]))
+ self.strides = [1] + [2] * (self.num_layers - 1)
+
+ self.encoder = nn.Sequential(
+ conv3x3(in_planes=self.in_channels, out_planes=self.in_planes),
+ *[
+ self._configure_wide_layer(
+ in_planes=in_planes,
+ out_planes=out_planes,
+ stride=stride,
+ activation=activation,
+ )
+ for (in_planes, out_planes), stride in zip(
+ self.num_stages, self.strides
+ )
+ ],
+ )
+
+ self.decoder = (
+ nn.Sequential(
+ nn.BatchNorm2d(self.num_stages[-1][-1], momentum=0.8),
+ self.activation,
+ Reduce("b c h w -> b c", "mean"),
+ nn.Linear(
+ in_features=self.num_stages[-1][-1], out_features=self.num_classes
+ ),
+ )
+ if use_decoder
+ else None
+ )
+
+ self.apply(conv_init)
+
+ def _configure_wide_layer(
+ self, in_planes: int, out_planes: int, stride: int, activation: str
+ ) -> List:
+ strides = [stride] + [1] * (self.num_blocks - 1)
+ planes = [out_planes] * len(strides)
+ planes = [(in_planes, out_planes)] + list(zip(planes, planes[1:]))
+ return nn.Sequential(
+ *[
+ self.block(
+ in_planes=in_planes,
+ out_planes=out_planes,
+ dropout_rate=self.dropout_rate,
+ stride=stride,
+ activation=activation,
+ )
+ for (in_planes, out_planes), stride in zip(planes, strides)
+ ]
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Feedforward pass."""
+ if len(x.shape) == 3:
+ x = x.unsqueeze(0)
+ x = self.encoder(x)
+ if self.decoder is not None:
+ x = self.decoder(x)
+ return x
diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt
index 86cf103..32c83cc 100644
--- a/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt
+++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt
Binary files differ
diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt
index a5c6aaf..a25bcd1 100644
--- a/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt
+++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt
Binary files differ
diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_SpinalVGG_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_SpinalVGG_weights.pt
new file mode 100644
index 0000000..e720299
--- /dev/null
+++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_SpinalVGG_weights.pt
Binary files differ
diff --git a/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt b/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt
new file mode 100644
index 0000000..9aec6ae
--- /dev/null
+++ b/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt
Binary files differ