diff options
Diffstat (limited to 'src/text_recognizer')
25 files changed, 1537 insertions, 242 deletions
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py index 05f74f6..ede4541 100644 --- a/src/text_recognizer/datasets/__init__.py +++ b/src/text_recognizer/datasets/__init__.py @@ -10,15 +10,23 @@ from .emnist_lines_dataset import ( EmnistLinesDataset, get_samples_by_character, ) -from .util import fetch_data_loaders, Transpose +from .iam_dataset import IamDataset +from .iam_lines_dataset import IamLinesDataset +from .iam_paragraphs_dataset import IamParagraphsDataset +from .util import _download_raw_dataset, compute_sha256, download_url, Transpose __all__ = [ + "_download_raw_dataset", + "compute_sha256", "construct_image_from_string", "DATA_DIRNAME", + "download_url", "EmnistDataset", "EmnistMapper", "EmnistLinesDataset", - "fetch_data_loaders", "get_samples_by_character", + "IamDataset", + "IamLinesDataset", + "IamParagraphsDataset", "Transpose", ] diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py index 49ebad3..0715aae 100644 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ b/src/text_recognizer/datasets/emnist_dataset.py @@ -152,8 +152,7 @@ class EmnistDataset(Dataset): """Loads the dataset and the mappings. Args: - train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to - False. + train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False. sample_to_balance (bool): Resamples the dataset to make it balanced. Defaults to False. subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None. transform (Optional[Callable]): Transform(s) for input data. Defaults to None. @@ -181,17 +180,37 @@ class EmnistDataset(Dataset): self.seed = seed self._mapper = EmnistMapper() - self.input_shape = self._mapper.input_shape + self._input_shape = self._mapper.input_shape self.num_classes = self._mapper.num_classes # Load dataset. - self.data, self.targets = self.load_emnist_dataset() + self._data, self._targets = self.load_emnist_dataset() + + @property + def data(self) -> Tensor: + """The input data.""" + return self._data + + @property + def targets(self) -> Tensor: + """The target data.""" + return self._targets + + @property + def input_shape(self) -> Tuple: + """Input shape of the data.""" + return self._input_shape @property def mapper(self) -> EmnistMapper: """Returns the EmnistMapper.""" return self._mapper + @property + def inverse_mapping(self) -> Dict: + """Returns the inverse mapping from character to index.""" + return self.mapper.inverse_mapping + def __len__(self) -> int: """Returns the length of the dataset.""" return len(self.data) @@ -220,11 +239,6 @@ class EmnistDataset(Dataset): return data, targets - @property - def __name__(self) -> str: - """Returns the name of the dataset.""" - return "EmnistDataset" - def __repr__(self) -> str: """Returns information about the dataset.""" return ( diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py index b0617f5..656131a 100644 --- a/src/text_recognizer/datasets/emnist_lines_dataset.py +++ b/src/text_recognizer/datasets/emnist_lines_dataset.py @@ -9,8 +9,8 @@ from loguru import logger import numpy as np import torch from torch import Tensor -from torch.utils.data import DataLoader, Dataset -from torchvision.transforms import Compose, Normalize, ToTensor +from torch.utils.data import Dataset +from torchvision.transforms import ToTensor from text_recognizer.datasets import ( DATA_DIRNAME, @@ -20,6 +20,7 @@ from text_recognizer.datasets import ( ) from text_recognizer.datasets.sentence_generator import SentenceGenerator from text_recognizer.datasets.util import Transpose +from text_recognizer.networks import sliding_window DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines" @@ -55,7 +56,7 @@ class EmnistLinesDataset(Dataset): self.transform = transform if self.transform is None: - self.transform = Compose([ToTensor()]) + self.transform = ToTensor() self.target_transform = target_transform if self.target_transform is None: @@ -63,14 +64,14 @@ class EmnistLinesDataset(Dataset): # Extract dataset information. self._mapper = EmnistMapper() - self.input_shape = self._mapper.input_shape + self._input_shape = self._mapper.input_shape self.num_classes = self._mapper.num_classes self.max_length = max_length self.min_overlap = min_overlap self.max_overlap = max_overlap self.num_samples = num_samples - self.input_shape = ( + self._input_shape = ( self.input_shape[0], self.input_shape[1] * self.max_length, ) @@ -84,6 +85,11 @@ class EmnistLinesDataset(Dataset): # Load dataset. self._load_or_generate_data() + @property + def input_shape(self) -> Tuple: + """Input shape of the data.""" + return self._input_shape + def __len__(self) -> int: """Returns the length of the dataset.""" return len(self.data) @@ -112,11 +118,6 @@ class EmnistLinesDataset(Dataset): return data, targets - @property - def __name__(self) -> str: - """Returns the name of the dataset.""" - return "EmnistLinesDataset" - def __repr__(self) -> str: """Returns information about the dataset.""" return ( @@ -136,13 +137,18 @@ class EmnistLinesDataset(Dataset): return self._mapper @property + def mapping(self) -> Dict: + """Return EMNIST mapping from index to character.""" + return self._mapper.mapping + + @property def data_filename(self) -> Path: """Path to the h5 file.""" filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt" if self.train: filename = "train_" + filename else: - filename = "val_" + filename + filename = "test_" + filename return DATA_DIRNAME / filename def _load_or_generate_data(self) -> None: @@ -184,7 +190,7 @@ class EmnistLinesDataset(Dataset): ) targets = convert_strings_to_categorical_labels( - targets, self.emnist.inverse_mapping + targets, emnist.inverse_mapping ) f.create_dataset("data", data=data, dtype="u1", compression="lzf") @@ -322,13 +328,13 @@ def create_datasets( min_overlap: float = 0, max_overlap: float = 0.33, num_train: int = 10000, - num_val: int = 1000, + num_test: int = 1000, ) -> None: """Creates a training an validation dataset of Emnist lines.""" emnist_train = EmnistDataset(train=True, sample_to_balance=True) - emnist_val = EmnistDataset(train=False, sample_to_balance=True) - datasets = [emnist_train, emnist_val] - num_samples = [num_train, num_val] + emnist_test = EmnistDataset(train=False, sample_to_balance=True) + datasets = [emnist_train, emnist_test] + num_samples = [num_train, num_test] for num, train, dataset in zip(num_samples, [True, False], datasets): emnist_lines = EmnistLinesDataset( train=train, diff --git a/src/text_recognizer/datasets/iam_dataset.py b/src/text_recognizer/datasets/iam_dataset.py new file mode 100644 index 0000000..5e47350 --- /dev/null +++ b/src/text_recognizer/datasets/iam_dataset.py @@ -0,0 +1,134 @@ +"""Class for loading the IAM dataset, which encompasses both paragraphs and lines, with associated utilities.""" +import os +from typing import Any, Dict, List +import zipfile + +from boltons.cacheutils import cachedproperty +import defusedxml.ElementTree as ET +from loguru import logger +import toml +from torch.utils.data import Dataset + +from text_recognizer.datasets import DATA_DIRNAME +from text_recognizer.datasets.util import _download_raw_dataset + +RAW_DATA_DIRNAME = DATA_DIRNAME / "raw" / "iam" +METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" +EXTRACTED_DATASET_DIRNAME = RAW_DATA_DIRNAME / "iamdb" + +DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be. +LINE_REGION_PADDING = 0 # Add this many pixels around the exact coordinates. + + +class IamDataset(Dataset): + """IAM dataset. + + "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text, + which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels." + From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database + + The data split we will use is + IAM lines Large Writer Independent Text Line Recognition Task (lwitlrt): 9,862 text lines. + The validation set has been merged into the train set. + The train set has 7,101 lines from 326 writers. + The test set has 1,861 lines from 128 writers. + The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only. + + """ + + def __init__(self) -> None: + self.metadata = toml.load(METADATA_FILENAME) + + def load_or_generate_data(self) -> None: + """Downloads IAM dataset if xml files does not exist.""" + if not self.xml_filenames: + self._download_iam() + + @property + def xml_filenames(self) -> List: + """List of xml filenames.""" + return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml")) + + @property + def form_filenames(self) -> List: + """List of forms filenames.""" + return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg")) + + def _download_iam(self) -> None: + curdir = os.getcwd() + os.chdir(RAW_DATA_DIRNAME) + _download_raw_dataset(self.metadata) + _extract_raw_dataset(self.metadata) + os.chdir(curdir) + + @property + def form_filenames_by_id(self) -> Dict: + """Creates a dictionary with filenames as keys and forms as values.""" + return {filename.stem: filename for filename in self.form_filenames} + + @cachedproperty + def line_strings_by_id(self) -> Dict: + """Return a dict from name of IAM form to a list of line texts in it.""" + return { + filename.stem: _get_line_strings_from_xml_file(filename) + for filename in self.xml_filenames + } + + @cachedproperty + def line_regions_by_id(self) -> Dict: + """Return a dict from name of IAM form to a list of (x1, x2, y1, y2) coordinates of all lines in it.""" + return { + filename.stem: _get_line_regions_from_xml_file(filename) + for filename in self.xml_filenames + } + + def __repr__(self) -> str: + """Print info about dataset.""" + return "IAM Dataset\n" f"Number of forms: {len(self.xml_filenames)}\n" + + +def _extract_raw_dataset(metadata: Dict) -> None: + logger.info("Extracting IAM data.") + with zipfile.ZipFile(metadata["filename"], "r") as zip_file: + zip_file.extractall() + + +def _get_line_strings_from_xml_file(filename: str) -> List[str]: + """Get the text content of each line. Note that we replace " with ".""" + xml_root_element = ET.parse(filename).getroot() # nosec + xml_line_elements = xml_root_element.findall("handwritten-part/line") + return [el.attrib["text"].replace(""", '"') for el in xml_line_elements] + + +def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]: + """Get the line region dict for each line.""" + xml_root_element = ET.parse(filename).getroot() # nosec + xml_line_elements = xml_root_element.findall("handwritten-part/line") + return [_get_line_region_from_xml_element(el) for el in xml_line_elements] + + +def _get_line_region_from_xml_element(xml_line: Any) -> Dict[str, int]: + """Extracts coordinates for each line of text.""" + # TODO: fix input! + word_elements = xml_line.findall("word/cmp") + x1s = [int(el.attrib["x"]) for el in word_elements] + y1s = [int(el.attrib["y"]) for el in word_elements] + x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements] + y2s = [int(el.attrib["y"]) + int(el.attrib["height"]) for el in word_elements] + return { + "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, + "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, + "x2": max(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, + "y2": max(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, + } + + +def main() -> None: + """Initializes the dataset and print info about the dataset.""" + dataset = IamDataset() + dataset.load_or_generate_data() + print(dataset) + + +if __name__ == "__main__": + main() diff --git a/src/text_recognizer/datasets/iam_lines_dataset.py b/src/text_recognizer/datasets/iam_lines_dataset.py new file mode 100644 index 0000000..477f500 --- /dev/null +++ b/src/text_recognizer/datasets/iam_lines_dataset.py @@ -0,0 +1,126 @@ +"""IamLinesDataset class.""" +from typing import Callable, Dict, List, Optional, Tuple, Union + +import h5py +from loguru import logger +import torch +from torch import Tensor +from torch.utils.data import Dataset +from torchvision.transforms import ToTensor + +from text_recognizer.datasets.emnist_dataset import DATA_DIRNAME, EmnistMapper +from text_recognizer.datasets.util import compute_sha256, download_url + + +PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_lines" +PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "iam_lines.h5" +PROCESSED_DATA_URL = ( + "https://s3-us-west-2.amazonaws.com/fsdl-public-assets/iam_lines.h5" +) + + +class IamLinesDataset(Dataset): + """IAM lines datasets for handwritten text lines.""" + + def __init__( + self, + train: bool = False, + subsample_fraction: float = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + ) -> None: + self.train = train + self.split = "train" if self.train else "test" + self._mapper = EmnistMapper() + self.num_classes = self.mapper.num_classes + + # Set transforms. + self.transform = transform + if self.transform is None: + self.transform = ToTensor() + + self.target_transform = target_transform + if self.target_transform is None: + self.target_transform = torch.tensor + + self.subsample_fraction = subsample_fraction + self.data = None + self.targets = None + + @property + def mapper(self) -> EmnistMapper: + """Returns the EmnistMapper.""" + return self._mapper + + @property + def mapping(self) -> Dict: + """Return EMNIST mapping from index to character.""" + return self._mapper.mapping + + @property + def input_shape(self) -> Tuple: + """Input shape of the data.""" + return self.data.shape[1:] + + @property + def output_shape(self) -> Tuple: + """Output shape of the data.""" + return self.targets.shape[1:] + (self.num_classes,) + + def __len__(self) -> int: + """Returns the length of the dataset.""" + return len(self.data) + + def load_or_generate_data(self) -> None: + """Load or generate dataset data.""" + if not PROCESSED_DATA_FILENAME.exists(): + PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) + logger.info("Downloading IAM lines...") + download_url(PROCESSED_DATA_URL, PROCESSED_DATA_FILENAME) + with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: + self.data = f[f"x_{self.split}"][:] + self.targets = f[f"y_{self.split}"][:] + self._subsample() + + def _subsample(self) -> None: + """Only a fraction of the data will be loaded.""" + if self.subsample_fraction is None: + return + + num_samples = int(self.data.shape[0] * self.subsample_fraction) + self.data = self.data[:num_samples] + self.targets = self.targets[:num_samples] + + def __repr__(self) -> str: + """Print info about the dataset.""" + return ( + "IAM Lines Dataset\n" # pylint: disable=no-member + f"Number classes: {self.num_classes}\n" + f"Mapping: {self.mapper.mapping}\n" + f"Data: {self.data.shape}\n" + f"Targets: {self.targets.shape}\n" + ) + + def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]: + """Fetches data, target pair of the dataset for a given and index or indices. + + Args: + index (Union[int, Tensor]): Either a list or int of indices/index. + + Returns: + Tuple[Tensor, Tensor]: Data target pair. + + """ + if torch.is_tensor(index): + index = index.tolist() + + data = self.data[index] + targets = self.targets[index] + + if self.transform: + data = self.transform(data) + + if self.target_transform: + targets = self.target_transform(targets) + + return data, targets diff --git a/src/text_recognizer/datasets/iam_paragraphs_dataset.py b/src/text_recognizer/datasets/iam_paragraphs_dataset.py new file mode 100644 index 0000000..d65b346 --- /dev/null +++ b/src/text_recognizer/datasets/iam_paragraphs_dataset.py @@ -0,0 +1,322 @@ +"""IamParagraphsDataset class and functions for data processing.""" +from typing import Callable, Dict, List, Optional, Tuple, Union + +import click +import cv2 +import h5py +from loguru import logger +import numpy as np +import torch +from torch import Tensor +from torch.utils.data import Dataset +from torchvision.transforms import ToTensor + +from text_recognizer import util +from text_recognizer.datasets.emnist_dataset import DATA_DIRNAME, EmnistMapper +from text_recognizer.datasets.iam_dataset import IamDataset +from text_recognizer.datasets.util import compute_sha256, download_url + +INTERIM_DATA_DIRNAME = DATA_DIRNAME / "interim" / "iam_paragraphs" +DEBUG_CROPS_DIRNAME = INTERIM_DATA_DIRNAME / "debug_crops" +PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_paragraphs" +CROPS_DIRNAME = PROCESSED_DATA_DIRNAME / "crops" +GT_DIRNAME = PROCESSED_DATA_DIRNAME / "gt" + +PARAGRAPH_BUFFER = 50 # Pixels in the IAM form images to leave around the lines. +TEST_FRACTION = 0.2 +SEED = 4711 + + +class IamParagraphsDataset(Dataset): + """IAM Paragraphs dataset for paragraphs of handwritten text. + + TODO: __getitem__, __len__, get_data_target_from_id + + """ + + def __init__( + self, + train: bool = False, + subsample_fraction: float = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + ) -> None: + + # Load Iam dataset. + self.iam_dataset = IamDataset() + + self.train = train + self.split = "train" if self.train else "test" + self.num_classes = 3 + self._input_shape = (256, 256) + self._output_shape = self._input_shape + (self.num_classes,) + self.subsample_fraction = subsample_fraction + + # Set transforms. + self.transform = transform + if self.transform is None: + self.transform = ToTensor() + + self.target_transform = target_transform + if self.target_transform is None: + self.target_transform = torch.tensor + + self._data = None + self._targets = None + self._ids = None + + def __len__(self) -> int: + """Returns the length of the dataset.""" + return len(self.data) + + def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]: + """Fetches data, target pair of the dataset for a given and index or indices. + + Args: + index (Union[int, Tensor]): Either a list or int of indices/index. + + Returns: + Tuple[Tensor, Tensor]: Data target pair. + + """ + if torch.is_tensor(index): + index = index.tolist() + + data = self.data[index] + targets = self.targets[index] + + if self.transform: + data = self.transform(data) + + if self.target_transform: + targets = self.target_transform(targets) + + return data, targets + + @property + def input_shape(self) -> Tuple: + """Input shape of the data.""" + return self._input_shape + + @property + def output_shape(self) -> Tuple: + """Output shape of the data.""" + return self._output_shape + + @property + def data(self) -> Tensor: + """The input data.""" + return self._data + + @property + def targets(self) -> Tensor: + """The target data.""" + return self._targets + + @property + def ids(self) -> Tensor: + """Ids of the dataset.""" + return self._ids + + def get_data_and_target_from_id(self, id_: str) -> Tuple[Tensor, Tensor]: + """Get data target pair from id.""" + ind = self.ids.index(id_) + return self.data[ind], self.targets[ind] + + def load_or_generate_data(self) -> None: + """Load or generate dataset data.""" + num_actual = len(list(CROPS_DIRNAME.glob("*.jpg"))) + num_targets = len(self.iam_dataset.line_regions_by_id) + + if num_actual < num_targets - 2: + self._process_iam_paragraphs() + + self._data, self._targets, self._ids = _load_iam_paragraphs() + self._get_random_split() + self._subsample() + + def _get_random_split(self) -> None: + np.random.seed(SEED) + num_train = int((1 - TEST_FRACTION) * self.data.shape[0]) + indices = np.random.permutation(self.data.shape[0]) + train_indices, test_indices = indices[:num_train], indices[num_train:] + if self.train: + self._data = self.data[train_indices] + self._targets = self.targets[train_indices] + else: + self._data = self.data[test_indices] + self._targets = self.targets[test_indices] + + def _process_iam_paragraphs(self) -> None: + """Crop the part with the text. + + For each page, crop out the part of it that correspond to the paragraph of text, and make sure all crops are + self.input_shape. The ground truth data is the same size, with a one-hot vector at each pixel + corresponding to labels 0=background, 1=odd-numbered line, 2=even-numbered line + """ + crop_dims = self._decide_on_crop_dims() + CROPS_DIRNAME.mkdir(parents=True, exist_ok=True) + DEBUG_CROPS_DIRNAME.mkdir(parents=True, exist_ok=True) + GT_DIRNAME.mkdir(parents=True, exist_ok=True) + logger.info( + f"Cropping paragraphs, generating ground truth, and saving debugging images to {DEBUG_CROPS_DIRNAME}" + ) + for filename in self.iam_dataset.form_filenames: + id_ = filename.stem + line_region = self.iam_dataset.line_regions_by_id[id_] + _crop_paragraph_image(filename, line_region, crop_dims, self.input_shape) + + def _decide_on_crop_dims(self) -> Tuple[int, int]: + """Decide on the dimensions to crop out of the form image. + + Since image width is larger than a comfortable crop around the longest paragraph, + we will make the crop a square form factor. + And since the found dimensions 610x610 are pretty close to 512x512, + we might as well resize crops and make it exactly that, which lets us + do all kinds of power-of-2 pooling and upsampling should we choose to. + + Returns: + Tuple[int, int]: A tuple of crop dimensions. + + Raises: + RuntimeError: When max crop height is larger than max crop width. + + """ + + sample_form_filename = self.iam_dataset.form_filenames[0] + sample_image = util.read_image(sample_form_filename, grayscale=True) + max_crop_width = sample_image.shape[1] + max_crop_height = _get_max_paragraph_crop_height( + self.iam_dataset.line_regions_by_id + ) + if not max_crop_height <= max_crop_width: + raise RuntimeError( + f"Max crop height is larger then max crop width: {max_crop_height} >= {max_crop_width}" + ) + + crop_dims = (max_crop_width, max_crop_width) + logger.info( + f"Max crop width and height were found to be {max_crop_width}x{max_crop_height}." + ) + logger.info(f"Setting them to {max_crop_width}x{max_crop_width}") + return crop_dims + + def _subsample(self) -> None: + """Only this fraction of the data will be loaded.""" + if self.subsample_fraction is None: + return + num_subsample = int(self.data.shape[0] * self.subsample_fraction) + self.data = self.data[:num_subsample] + self.targets = self.targets[:num_subsample] + + def __repr__(self) -> str: + """Return info about the dataset.""" + return ( + "IAM Paragraph Dataset\n" # pylint: disable=no-member + f"Num classes: {self.num_classes}\n" + f"Data: {self.data.shape}\n" + f"Targets: {self.targets.shape}\n" + ) + + +def _get_max_paragraph_crop_height(line_regions_by_id: Dict) -> int: + heights = [] + for regions in line_regions_by_id.values(): + min_y1 = min(region["y1"] for region in regions) - PARAGRAPH_BUFFER + max_y2 = max(region["y2"] for region in regions) + PARAGRAPH_BUFFER + height = max_y2 - min_y1 + heights.append(height) + return max(heights) + + +def _crop_paragraph_image( + filename: str, line_regions: Dict, crop_dims: Tuple[int, int], final_dims: Tuple +) -> None: + image = util.read_image(filename, grayscale=True) + + min_y1 = min(region["y1"] for region in line_regions) - PARAGRAPH_BUFFER + max_y2 = max(region["y2"] for region in line_regions) + PARAGRAPH_BUFFER + height = max_y2 - min_y1 + crop_height = crop_dims[0] + buffer = (crop_height - height) // 2 + + # Generate image crop. + image_crop = 255 * np.ones(crop_dims, dtype=np.uint8) + try: + image_crop[buffer : buffer + height] = image[min_y1:max_y2] + except Exception as e: # pylint: disable=broad-except + logger.error(f"Rescued {filename}: {e}") + return + + # Generate ground truth. + gt_image = np.zeros_like(image_crop, dtype=np.uint8) + for index, region in enumerate(line_regions): + gt_image[ + (region["y1"] - min_y1 + buffer) : (region["y2"] - min_y1 + buffer), + region["x1"] : region["x2"], + ] = (index % 2 + 1) + + # Generate image for debugging. + import matplotlib.pyplot as plt + + cmap = plt.get_cmap("Set1") + image_crop_for_debug = np.dstack([image_crop, image_crop, image_crop]) + for index, region in enumerate(line_regions): + color = [255 * _ for _ in cmap(index)[:-1]] + cv2.rectangle( + image_crop_for_debug, + (region["x1"], region["y1"] - min_y1 + buffer), + (region["x2"], region["y2"] - min_y1 + buffer), + color, + 3, + ) + image_crop_for_debug = cv2.resize( + image_crop_for_debug, final_dims, interpolation=cv2.INTER_AREA + ) + util.write_image(image_crop_for_debug, DEBUG_CROPS_DIRNAME / f"{filename.stem}.jpg") + + image_crop = cv2.resize(image_crop, final_dims, interpolation=cv2.INTER_AREA) + util.write_image(image_crop, CROPS_DIRNAME / f"{filename.stem}.jpg") + + gt_image = cv2.resize(gt_image, final_dims, interpolation=cv2.INTER_NEAREST) + util.write_image(gt_image, GT_DIRNAME / f"{filename.stem}.png") + + +def _load_iam_paragraphs() -> None: + logger.info("Loading IAM paragraph crops and ground truth from image files...") + images = [] + gt_images = [] + ids = [] + for filename in CROPS_DIRNAME.glob("*.jpg"): + id_ = filename.stem + image = util.read_image(filename, grayscale=True) + image = 1.0 - image / 255 + + gt_filename = GT_DIRNAME / f"{id_}.png" + gt_image = util.read_image(gt_filename, grayscale=True) + + images.append(image) + gt_images.append(gt_image) + ids.append(id_) + images = np.array(images).astype(np.float32) + gt_images = np.array(gt_images).astype(np.uint8) + ids = np.array(ids) + return images, gt_images, ids + + +@click.command() +@click.option( + "--subsample_fraction", + type=float, + default=0.0, + help="The subsampling factor of the dataset.", +) +def main(subsample_fraction: float) -> None: + """Load dataset and print info.""" + dataset = IamParagraphsDataset(subsample_fraction=subsample_fraction) + dataset.load_or_generate_data() + print(dataset) + + +if __name__ == "__main__": + main() diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py index 76bd85f..dd16bed 100644 --- a/src/text_recognizer/datasets/util.py +++ b/src/text_recognizer/datasets/util.py @@ -1,10 +1,17 @@ """Util functions for datasets.""" +import hashlib import importlib -from typing import Callable, Dict, List, Type +import os +from pathlib import Path +from typing import Callable, Dict, List, Optional, Type, Union +from urllib.request import urlopen, urlretrieve +import cv2 +from loguru import logger import numpy as np from PIL import Image from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm class Transpose: @@ -15,58 +22,48 @@ class Transpose: return np.array(image).swapaxes(0, 1) -def fetch_data_loaders( - splits: List[str], - dataset: str, - dataset_args: Dict, - batch_size: int = 128, - shuffle: bool = False, - num_workers: int = 0, - cuda: bool = True, -) -> Dict[str, DataLoader]: - """Fetches DataLoaders for given split(s) as a dictionary. - - Loads the dataset class given, and loads it with the dataset arguments, for the number of splits specified. Then - calls the DataLoader. Added to a dictionary with the split as key and DataLoader as value. - - Args: - splits (List[str]): One or both of the dataset splits "train" and "val". - dataset (str): The name of the dataset. - dataset_args (Dict): The dataset arguments. - batch_size (int): How many samples per batch to load. Defaults to 128. - shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False. - num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be - loaded in the main process. Defaults to 0. - cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning - them. Defaults to True. - - Returns: - Dict[str, DataLoader]: Dictionary with split as key and PyTorch DataLoader as value. +def compute_sha256(filename: Union[Path, str]) -> str: + """Returns the SHA256 checksum of a file.""" + with open(filename, "rb") as f: + return hashlib.sha256(f.read()).hexdigest() - """ - - def check_dataset_args(args: Dict, split: str) -> Dict: - """Adds train flag to the dataset args.""" - args["train"] = True if split == "train" else False - return args - - # Import dataset module. - datasets_module = importlib.import_module("text_recognizer.datasets") - dataset_ = getattr(datasets_module, dataset) - data_loaders = {} +class TqdmUpTo(tqdm): + """TQDM progress bar when downloading files. - for split in ["train", "val"]: - if split in splits: + From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py - data_loader = DataLoader( - dataset=dataset_(**check_dataset_args(dataset_args, split)), - batch_size=batch_size, - shuffle=shuffle, - num_workers=num_workers, - pin_memory=cuda, - ) - - data_loaders[split] = data_loader + """ - return data_loaders + def update_to( + self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None + ) -> None: + """Updates the progress bar. + + Args: + blocks (int): Number of blocks transferred so far. Defaults to 1. + block_size (int): Size of each block, in tqdm units. Defaults to 1. + total_size (Optional[int]): Total size in tqdm units. Defaults to None. + """ + if total_size is not None: + self.total = total_size # pylint: disable=attribute-defined-outside-init + self.update(blocks * block_size - self.n) + + +def download_url(url: str, filename: str) -> None: + """Downloads a file from url to filename, with a progress bar.""" + with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: + urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec + + +def _download_raw_dataset(metadata: Dict) -> None: + if os.path.exists(metadata["filename"]): + return + logger.info(f"Downloading raw dataset from {metadata['url']}...") + download_url(metadata["url"], metadata["filename"]) + logger.info("Computing SHA-256...") + sha256 = compute_sha256(metadata["filename"]) + if sha256 != metadata["sha256"]: + raise ValueError( + "Downloaded data file SHA-256 does not match that listed in metadata document." + ) diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py index ff10a07..a3cfc15 100644 --- a/src/text_recognizer/models/__init__.py +++ b/src/text_recognizer/models/__init__.py @@ -1,6 +1,7 @@ """Model modules.""" from .base import Model from .character_model import CharacterModel -from .metrics import accuracy +from .line_ctc_model import LineCTCModel +from .metrics import accuracy, cer, wer -__all__ = ["Model", "CharacterModel", "accuracy"] +__all__ = ["Model", "cer", "CharacterModel", "LineCTCModel", "accuracy", "wer"] diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index 3a84a11..153e19a 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -2,6 +2,7 @@ from abc import ABC, abstractmethod from glob import glob +import importlib from pathlib import Path import re import shutil @@ -10,9 +11,12 @@ from typing import Callable, Dict, Optional, Tuple, Type from loguru import logger import torch from torch import nn +from torch import Tensor +from torch.optim.swa_utils import AveragedModel, SWALR +from torch.utils.data import DataLoader, Dataset, random_split from torchsummary import summary -from text_recognizer.datasets import EmnistMapper, fetch_data_loaders +from text_recognizer.datasets import EmnistMapper WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights" @@ -23,8 +27,9 @@ class Model(ABC): def __init__( self, network_fn: Type[nn.Module], + dataset: Type[Dataset], network_args: Optional[Dict] = None, - data_loader_args: Optional[Dict] = None, + dataset_args: Optional[Dict] = None, metrics: Optional[Dict] = None, criterion: Optional[Callable] = None, criterion_args: Optional[Dict] = None, @@ -32,14 +37,16 @@ class Model(ABC): optimizer_args: Optional[Dict] = None, lr_scheduler: Optional[Callable] = None, lr_scheduler_args: Optional[Dict] = None, + swa_args: Optional[Dict] = None, device: Optional[str] = None, ) -> None: """Base class, to be inherited by model for specific type of data. Args: network_fn (Type[nn.Module]): The PyTorch network. + dataset (Type[Dataset]): A dataset class. network_args (Optional[Dict]): Arguments for the network. Defaults to None. - data_loader_args (Optional[Dict]): Arguments for the DataLoader. + dataset_args (Optional[Dict]): Arguments for the dataset. metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None. criterion (Optional[Callable]): The criterion to evaulate the preformance of the network. Defaults to None. @@ -49,107 +56,181 @@ class Model(ABC): lr_scheduler (Optional[Callable]): A PyTorch learning rate scheduler. Defaults to None. lr_scheduler_args (Optional[Dict]): Dict of arguments for learning rate scheduler. Defaults to None. + swa_args (Optional[Dict]): Dict of arguments for stochastic weight averaging. Defaults to + None. device (Optional[str]): Name of the device to train on. Defaults to None. """ + # Has to be set in subclass. + self._mapper = None - # Configure data loaders and dataset info. - dataset_name, self._data_loaders, self._mapper = self._configure_data_loader( - data_loader_args - ) - self._input_shape = self._mapper.input_shape + # Placeholder. + self._input_shape = None + + self.dataset = dataset + self.dataset_args = dataset_args + + # Placeholders for datasets. + self.train_dataset = None + self.val_dataset = None + self.test_dataset = None - self._name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}" + # Stochastic Weight Averaging placeholders. + self.swa_args = swa_args + self._swa_start = None + self._swa_scheduler = None + self._swa_network = None - if metrics is not None: - self._metrics = metrics + # Experiment directory. + self.model_dir = None + + # Flag for configured model. + self.is_configured = False + self.data_prepared = False + + # Flag for stopping training. + self.stop_training = False + + self._name = ( + f"{self.__class__.__name__}_{dataset.__name__}_{network_fn.__name__}" + ) + + self._metrics = metrics if metrics is not None else None # Set the device. - if device is None: - self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - else: - self._device = device + self._device = ( + torch.device("cuda" if torch.cuda.is_available() else "cpu") + if device is None + else device + ) # Configure network. - self._network, self._network_args = self._configure_network( - network_fn, network_args - ) + self._network = None + self._network_args = network_args + self._configure_network(network_fn) - # To device. - self._network.to(self._device) + # Place network on device (GPU). + self.to_device() + + # Loss and Optimizer placeholders for before loading. + self._criterion = criterion + self.criterion_args = criterion_args + + self._optimizer = optimizer + self.optimizer_args = optimizer_args + + self._lr_scheduler = lr_scheduler + self.lr_scheduler_args = lr_scheduler_args + + def configure_model(self) -> None: + """Configures criterion and optimizers.""" + if not self.is_configured: + self._configure_criterion() + self._configure_optimizers() + + # Prints a summary of the network in terminal. + self.summary() + + # Set this flag to true to prevent the model from configuring again. + self.is_configured = True + + def prepare_data(self) -> None: + """Prepare data for training.""" + # TODO add downloading. + if not self.data_prepared: + # Load train dataset. + train_dataset = self.dataset(train=True, **self.dataset_args["args"]) + + # Set input shape. + self._input_shape = train_dataset.input_shape + + # Split train dataset into a training and validation partition. + dataset_len = len(train_dataset) + train_len = int( + self.dataset_args["train_args"]["train_fraction"] * dataset_len + ) + val_len = dataset_len - train_len + self.train_dataset, self.val_dataset = random_split( + train_dataset, lengths=[train_len, val_len] + ) + + # Load test dataset. + self.test_dataset = self.dataset(train=False, **self.dataset_args["args"]) + + # Set the flag to true to disable ability to load data agian. + self.data_prepared = True - # Configure training objects. - self._criterion = self._configure_criterion(criterion, criterion_args) - self._optimizer, self._lr_scheduler = self._configure_optimizers( - optimizer, optimizer_args, lr_scheduler, lr_scheduler_args + def train_dataloader(self) -> DataLoader: + """Returns data loader for training set.""" + return DataLoader( + self.train_dataset, + batch_size=self.dataset_args["train_args"]["batch_size"], + num_workers=self.dataset_args["train_args"]["num_workers"], + shuffle=True, + pin_memory=True, ) - # Experiment directory. - self.model_dir = None + def val_dataloader(self) -> DataLoader: + """Returns data loader for validation set.""" + return DataLoader( + self.val_dataset, + batch_size=self.dataset_args["train_args"]["batch_size"], + num_workers=self.dataset_args["train_args"]["num_workers"], + shuffle=True, + pin_memory=True, + ) - # Flag for stopping training. - self.stop_training = False + def test_dataloader(self) -> DataLoader: + """Returns data loader for test set.""" + return DataLoader( + self.test_dataset, + batch_size=self.dataset_args["train_args"]["batch_size"], + num_workers=self.dataset_args["train_args"]["num_workers"], + shuffle=False, + pin_memory=True, + ) - def _configure_data_loader( - self, data_loader_args: Optional[Dict] - ) -> Tuple[str, Dict, EmnistMapper]: - """Loads data loader, dataset name, and dataset mapper.""" - if data_loader_args is not None: - data_loaders = fetch_data_loaders(**data_loader_args) - dataset = list(data_loaders.values())[0].dataset - dataset_name = dataset.__name__ - mapper = dataset.mapper - else: - self._mapper = EmnistMapper() - dataset_name = "*" - data_loaders = None - return dataset_name, data_loaders, mapper - - def _configure_network( - self, network_fn: Type[nn.Module], network_args: Optional[Dict] - ) -> Tuple[Type[nn.Module], Dict]: + def _configure_network(self, network_fn: Type[nn.Module]) -> None: """Loads the network.""" # If no network arguemnts are given, load pretrained weights if they exist. - if network_args is None: - network, network_args = self.load_weights(network_fn) + if self._network_args is None: + self.load_weights(network_fn) else: - network = network_fn(**network_args) - return network, network_args + self._network = network_fn(**self._network_args) - def _configure_criterion( - self, criterion: Optional[Callable], criterion_args: Optional[Dict] - ) -> Optional[Callable]: + def _configure_criterion(self) -> None: """Loads the criterion.""" - if criterion is not None: - _criterion = criterion(**criterion_args) - else: - _criterion = None - return _criterion + self._criterion = ( + self._criterion(**self.criterion_args) + if self._criterion is not None + else None + ) - def _configure_optimizers( - self, - optimizer: Optional[Callable], - optimizer_args: Optional[Dict], - lr_scheduler: Optional[Callable], - lr_scheduler_args: Optional[Dict], - ) -> Tuple[Optional[Callable], Optional[Callable]]: + def _configure_optimizers(self,) -> None: """Loads the optimizers.""" - if optimizer is not None: - _optimizer = optimizer(self._network.parameters(), **optimizer_args) + if self._optimizer is not None: + self._optimizer = self._optimizer( + self._network.parameters(), **self.optimizer_args + ) else: - _optimizer = None + self._optimizer = None - if _optimizer and lr_scheduler is not None: - if "OneCycleLR" in str(lr_scheduler): - lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders["train"]) - _lr_scheduler = lr_scheduler(_optimizer, **lr_scheduler_args) + if self._optimizer and self._lr_scheduler is not None: + if "OneCycleLR" in str(self._lr_scheduler): + self.lr_scheduler_args["steps_per_epoch"] = len(self.train_dataloader()) + self._lr_scheduler = self._lr_scheduler( + self._optimizer, **self.lr_scheduler_args + ) else: - _lr_scheduler = None + self._lr_scheduler = None - return _optimizer, _lr_scheduler + if self.swa_args is not None: + self._swa_start = self.swa_args["start"] + self._swa_scheduler = SWALR(self._optimizer, swa_lr=self.swa_args["lr"]) + self._swa_network = AveragedModel(self._network).to(self.device) @property - def __name__(self) -> str: + def name(self) -> str: """Returns the name of the model.""" return self._name @@ -159,7 +240,7 @@ class Model(ABC): return self._input_shape @property - def mapper(self) -> Dict: + def mapper(self) -> EmnistMapper: """Returns the mapper that maps between ints and chars.""" return self._mapper @@ -202,13 +283,24 @@ class Model(ABC): return self._lr_scheduler @property - def data_loaders(self) -> Optional[Dict]: - """Dataloaders.""" - return self._data_loaders + def swa_scheduler(self) -> Optional[Callable]: + """Returns the stochastic weight averaging scheduler.""" + return self._swa_scheduler + + @property + def swa_start(self) -> Optional[Callable]: + """Returns the start epoch of stochastic weight averaging.""" + return self._swa_start @property - def network(self) -> nn.Module: + def swa_network(self) -> Optional[Callable]: + """Returns the stochastic weight averaging network.""" + return self._swa_network + + @property + def network(self) -> Type[nn.Module]: """Neural network.""" + # Returns the SWA network if available. return self._network @property @@ -217,15 +309,27 @@ class Model(ABC): WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True) return str(WEIGHT_DIRNAME / f"{self._name}_weights.pt") - def summary(self) -> None: + def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor: + """Compute the loss.""" + return self.criterion(output, targets) + + def summary( + self, input_shape: Optional[Tuple[int, int, int]] = None, depth: int = 5 + ) -> None: """Prints a summary of the network architecture.""" - device = re.sub("[^A-Za-z]+", "", self.device) - if self._input_shape is not None: + + if input_shape is not None: + summary(self._network, input_shape, depth=depth, device=self.device) + elif self._input_shape is not None: input_shape = (1,) + tuple(self._input_shape) - summary(self._network, input_shape, device=device) + summary(self._network, input_shape, depth=depth, device=self.device) else: logger.warning("Could not print summary as input shape is not set.") + def to_device(self) -> None: + """Places the network on the device (GPU).""" + self._network.to(self._device) + def _get_state_dict(self) -> Dict: """Get the state dict of the model.""" state = {"model_state": self._network.state_dict()} @@ -236,69 +340,67 @@ class Model(ABC): if self._lr_scheduler is not None: state["scheduler_state"] = self._lr_scheduler.state_dict() + if self._swa_network is not None: + state["swa_network"] = self._swa_network.state_dict() + return state - def load_checkpoint(self, path: Path) -> int: + def load_from_checkpoint(self, checkpoint_path: Path) -> None: """Load a previously saved checkpoint. Args: - path (Path): Path to the experiment with the checkpoint. - - Returns: - epoch (int): The last epoch when the checkpoint was created. + checkpoint_path (Path): Path to the experiment with the checkpoint. """ logger.debug("Loading checkpoint...") - if not path.exists(): - logger.debug("File does not exist {str(path)}") + if not checkpoint_path.exists(): + logger.debug("File does not exist {str(checkpoint_path)}") - checkpoint = torch.load(str(path)) + checkpoint = torch.load(str(checkpoint_path)) self._network.load_state_dict(checkpoint["model_state"]) if self._optimizer is not None: self._optimizer.load_state_dict(checkpoint["optimizer_state"]) - # Does not work when loadning from previous checkpoint and trying to train beyond the last max epochs. - # if self._lr_scheduler is not None: - # self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"]) - - epoch = checkpoint["epoch"] + if self._lr_scheduler is not None: + # Does not work when loadning from previous checkpoint and trying to train beyond the last max epochs + # with OneCycleLR. + if self._lr_scheduler.__class__.__name__ != "OneCycleLR": + self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"]) - return epoch + if self._swa_network is not None: + self._swa_network.load_state_dict(checkpoint["swa_network"]) - def save_checkpoint(self, is_best: bool, epoch: int, val_metric: str) -> None: + def save_checkpoint( + self, checkpoint_path: Path, is_best: bool, epoch: int, val_metric: str + ) -> None: """Saves a checkpoint of the model. Args: + checkpoint_path (Path): Path to the experiment with the checkpoint. is_best (bool): If it is the currently best model. epoch (int): The epoch of the checkpoint. val_metric (str): Validation metric. - Raises: - ValueError: If the self.model_dir is not set. - """ state = self._get_state_dict() state["is_best"] = is_best state["epoch"] = epoch state["network_args"] = self._network_args - if self.model_dir is None: - raise ValueError("Experiment directory is not set.") - - self.model_dir.mkdir(parents=True, exist_ok=True) + checkpoint_path.mkdir(parents=True, exist_ok=True) logger.debug("Saving checkpoint...") - filepath = str(self.model_dir / "last.pt") + filepath = str(checkpoint_path / "last.pt") torch.save(state, filepath) if is_best: logger.debug( f"Found a new best {val_metric}. Saving best checkpoint and weights." ) - shutil.copyfile(filepath, str(self.model_dir / "best.pt")) + shutil.copyfile(filepath, str(checkpoint_path / "best.pt")) - def load_weights(self, network_fn: Type[nn.Module]) -> Tuple[Type[nn.Module], Dict]: + def load_weights(self, network_fn: Type[nn.Module]) -> None: """Load the network weights.""" logger.debug("Loading network with pretrained weights.") filename = glob(self.weights_filename)[0] @@ -308,13 +410,16 @@ class Model(ABC): ) # Loading state directory. state_dict = torch.load(filename, map_location=torch.device(self._device)) - network_args = state_dict["network_args"] + self._network_args = state_dict["network_args"] weights = state_dict["model_state"] # Initializes the network with trained weights. - network = network_fn(**self._network_args) - network.load_state_dict(weights) - return network, network_args + self._network = network_fn(**self._network_args) + self._network.load_state_dict(weights) + + if "swa_network" in state_dict: + self._swa_network = AveragedModel(self._network).to(self.device) + self._swa_network.load_state_dict(state_dict["swa_network"]) def save_weights(self, path: Path) -> None: """Save the network weights.""" diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py index 0fd7afd..64ba693 100644 --- a/src/text_recognizer/models/character_model.py +++ b/src/text_recognizer/models/character_model.py @@ -4,8 +4,10 @@ from typing import Callable, Dict, Optional, Tuple, Type, Union import numpy as np import torch from torch import nn +from torch.utils.data import Dataset from torchvision.transforms import ToTensor +from text_recognizer.datasets import EmnistMapper from text_recognizer.models.base import Model @@ -15,8 +17,9 @@ class CharacterModel(Model): def __init__( self, network_fn: Type[nn.Module], + dataset: Type[Dataset], network_args: Optional[Dict] = None, - data_loader_args: Optional[Dict] = None, + dataset_args: Optional[Dict] = None, metrics: Optional[Dict] = None, criterion: Optional[Callable] = None, criterion_args: Optional[Dict] = None, @@ -24,14 +27,16 @@ class CharacterModel(Model): optimizer_args: Optional[Dict] = None, lr_scheduler: Optional[Callable] = None, lr_scheduler_args: Optional[Dict] = None, + swa_args: Optional[Dict] = None, device: Optional[str] = None, ) -> None: """Initializes the CharacterModel.""" super().__init__( network_fn, + dataset, network_args, - data_loader_args, + dataset_args, metrics, criterion, criterion_args, @@ -39,8 +44,11 @@ class CharacterModel(Model): optimizer_args, lr_scheduler, lr_scheduler_args, + swa_args, device, ) + if self._mapper is None: + self._mapper = EmnistMapper() self.tensor_transform = ToTensor() self.softmax = nn.Softmax(dim=0) @@ -67,9 +75,13 @@ class CharacterModel(Model): # Put the image tensor on the device the model weights are on. image = image.to(self.device) - logits = self.network(image) + logits = ( + self.swa_network(image) + if self.swa_network is not None + else self.network(image) + ) - prediction = self.softmax(logits.data.squeeze()) + prediction = self.softmax(logits.squeeze(0)) index = int(torch.argmax(prediction, dim=0)) confidence_of_prediction = prediction[index] diff --git a/src/text_recognizer/models/line_ctc_model.py b/src/text_recognizer/models/line_ctc_model.py new file mode 100644 index 0000000..97308a7 --- /dev/null +++ b/src/text_recognizer/models/line_ctc_model.py @@ -0,0 +1,105 @@ +"""Defines the LineCTCModel class.""" +from typing import Callable, Dict, Optional, Tuple, Type, Union + +import numpy as np +import torch +from torch import nn +from torch import Tensor +from torch.utils.data import Dataset +from torchvision.transforms import ToTensor + +from text_recognizer.datasets import EmnistMapper +from text_recognizer.models.base import Model +from text_recognizer.networks import greedy_decoder + + +class LineCTCModel(Model): + """Model for predicting a sequence of characters from an image of a text line.""" + + def __init__( + self, + network_fn: Type[nn.Module], + dataset: Type[Dataset], + network_args: Optional[Dict] = None, + dataset_args: Optional[Dict] = None, + metrics: Optional[Dict] = None, + criterion: Optional[Callable] = None, + criterion_args: Optional[Dict] = None, + optimizer: Optional[Callable] = None, + optimizer_args: Optional[Dict] = None, + lr_scheduler: Optional[Callable] = None, + lr_scheduler_args: Optional[Dict] = None, + swa_args: Optional[Dict] = None, + device: Optional[str] = None, + ) -> None: + super().__init__( + network_fn, + dataset, + network_args, + dataset_args, + metrics, + criterion, + criterion_args, + optimizer, + optimizer_args, + lr_scheduler, + lr_scheduler_args, + swa_args, + device, + ) + if self._mapper is None: + self._mapper = EmnistMapper() + self.tensor_transform = ToTensor() + + def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor: + """Computes the CTC loss. + + Args: + output (Tensor): Model predictions. + targets (Tensor): Correct output sequence. + + Returns: + Tensor: The CTC loss. + + """ + input_lengths = torch.full( + size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long, + ) + target_lengths = torch.full( + size=(output.shape[1],), fill_value=targets.shape[1], dtype=torch.long, + ) + return self.criterion(output, targets, input_lengths, target_lengths) + + @torch.no_grad() + def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]: + """Predict on a single input.""" + if image.dtype == np.uint8: + # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. + image = self.tensor_transform(image) + + # Rescale image between 0 and 1. + if image.dtype == torch.uint8: + # If the image is an unscaled tensor. + image = image.type("torch.FloatTensor") / 255 + + # Put the image tensor on the device the model weights are on. + image = image.to(self.device) + log_probs = ( + self.swa_network(image) + if self.swa_network is not None + else self.network(image) + ) + + raw_pred, _ = greedy_decoder( + predictions=log_probs, + character_mapper=self.mapper, + blank_label=79, + collapse_repeated=True, + ) + + log_probs, _ = log_probs.max(dim=2) + + predicted_characters = "".join(raw_pred[0]) + confidence_of_prediction = torch.exp(log_probs.sum()).item() + + return predicted_characters, confidence_of_prediction diff --git a/src/text_recognizer/models/metrics.py b/src/text_recognizer/models/metrics.py index ac8d68e..6a26216 100644 --- a/src/text_recognizer/models/metrics.py +++ b/src/text_recognizer/models/metrics.py @@ -1,19 +1,89 @@ """Utility functions for models.""" - +import Levenshtein as Lev import torch +from torch import Tensor + +from text_recognizer.networks import greedy_decoder -def accuracy(outputs: torch.Tensor, labels: torch.Tensor) -> float: +def accuracy(outputs: Tensor, labels: Tensor) -> float: """Computes the accuracy. Args: - outputs (torch.Tensor): The output from the network. - labels (torch.Tensor): Ground truth labels. + outputs (Tensor): The output from the network. + labels (Tensor): Ground truth labels. Returns: float: The accuracy for the batch. """ _, predicted = torch.max(outputs.data, dim=1) - acc = (predicted == labels).sum().item() / labels.shape[0] + acc = (predicted == labels).sum().float() / labels.shape[0] + acc = acc.item() return acc + + +def cer(outputs: Tensor, targets: Tensor) -> float: + """Computes the character error rate. + + Args: + outputs (Tensor): The output from the network. + targets (Tensor): Ground truth labels. + + Returns: + float: The cer for the batch. + + """ + target_lengths = torch.full( + size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long, + ) + decoded_predictions, decoded_targets = greedy_decoder( + outputs, targets, target_lengths + ) + + lev_dist = 0 + + for prediction, target in zip(decoded_predictions, decoded_targets): + prediction = "".join(prediction) + target = "".join(target) + prediction, target = ( + prediction.replace(" ", ""), + target.replace(" ", ""), + ) + lev_dist += Lev.distance(prediction, target) + return lev_dist / len(decoded_predictions) + + +def wer(outputs: Tensor, targets: Tensor) -> float: + """Computes the Word error rate. + + Args: + outputs (Tensor): The output from the network. + targets (Tensor): Ground truth labels. + + Returns: + float: The wer for the batch. + + """ + target_lengths = torch.full( + size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long, + ) + decoded_predictions, decoded_targets = greedy_decoder( + outputs, targets, target_lengths + ) + + lev_dist = 0 + + for prediction, target in zip(decoded_predictions, decoded_targets): + prediction = "".join(prediction) + target = "".join(target) + + b = set(prediction.split() + target.split()) + word2char = dict(zip(b, range(len(b)))) + + w1 = [chr(word2char[w]) for w in prediction.split()] + w2 = [chr(word2char[w]) for w in target.split()] + + lev_dist += Lev.distance("".join(w1), "".join(w2)) + + return lev_dist / len(decoded_predictions) diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py index a83ca35..d20c86a 100644 --- a/src/text_recognizer/networks/__init__.py +++ b/src/text_recognizer/networks/__init__.py @@ -1,6 +1,19 @@ """Network modules.""" +from .ctc import greedy_decoder from .lenet import LeNet +from .line_lstm_ctc import LineRecurrentNetwork +from .misc import sliding_window from .mlp import MLP -from .residual_network import ResidualNetwork +from .residual_network import ResidualNetwork, ResidualNetworkEncoder +from .wide_resnet import WideResidualNetwork -__all__ = ["MLP", "LeNet", "ResidualNetwork"] +__all__ = [ + "greedy_decoder", + "MLP", + "LeNet", + "LineRecurrentNetwork", + "ResidualNetwork", + "ResidualNetworkEncoder", + "sliding_window", + "WideResidualNetwork", +] diff --git a/src/text_recognizer/networks/ctc.py b/src/text_recognizer/networks/ctc.py index 00ad47e..fc0d21d 100644 --- a/src/text_recognizer/networks/ctc.py +++ b/src/text_recognizer/networks/ctc.py @@ -1,10 +1,58 @@ """Decodes the CTC output.""" -# -# from typing import Tuple -# import torch -# -# -# def greedy_decoder( -# output, labels, label_length, blank_label, collapse_repeated=True -# ) -> Tuple[torch.Tensor, torch.Tensor]: -# pass +from typing import Callable, List, Optional, Tuple + +from einops import rearrange +import torch +from torch import Tensor + +from text_recognizer.datasets import EmnistMapper + + +def greedy_decoder( + predictions: Tensor, + targets: Optional[Tensor] = None, + target_lengths: Optional[Tensor] = None, + character_mapper: Optional[Callable] = None, + blank_label: int = 79, + collapse_repeated: bool = True, +) -> Tuple[List[str], List[str]]: + """Greedy CTC decoder. + + Args: + predictions (Tensor): Tenor of network predictions, shape [time, batch, classes]. + targets (Optional[Tensor]): Target tensor, shape is [batch, targets]. Defaults to None. + target_lengths (Optional[Tensor]): Length of each target tensor. Defaults to None. + character_mapper (Optional[Callable]): A emnist/character mapper for mapping integers to characters. Defaults + to None. + blank_label (int): The blank character to be ignored. Defaults to 79. + collapse_repeated (bool): Collapase consecutive predictions of the same character. Defaults to True. + + Returns: + Tuple[List[str], List[str]]: Tuple of decoded predictions and decoded targets. + + """ + + if character_mapper is None: + character_mapper = EmnistMapper() + + predictions = rearrange(torch.argmax(predictions, dim=2), "t b -> b t") + decoded_predictions = [] + decoded_targets = [] + for i, prediction in enumerate(predictions): + decoded_prediction = [] + decoded_target = [] + if targets is not None and target_lengths is not None: + for target_index in targets[i][: target_lengths[i]]: + if target_index == blank_label: + continue + decoded_target.append(character_mapper(int(target_index))) + decoded_targets.append(decoded_target) + for j, index in enumerate(prediction): + if index != blank_label: + if collapse_repeated and j != 0 and index == prediction[j - 1]: + continue + decoded_prediction.append(index.item()) + decoded_predictions.append( + [character_mapper(int(pred_index)) for pred_index in decoded_prediction] + ) + return decoded_predictions, decoded_targets diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py index 91d3f2c..53c575e 100644 --- a/src/text_recognizer/networks/lenet.py +++ b/src/text_recognizer/networks/lenet.py @@ -1,4 +1,4 @@ -"""Defines the LeNet network.""" +"""Implementation of the LeNet network.""" from typing import Callable, Dict, Optional, Tuple from einops.layers.torch import Rearrange @@ -9,7 +9,7 @@ from text_recognizer.networks.misc import activation_function class LeNet(nn.Module): - """LeNet network.""" + """LeNet network for character prediction.""" def __init__( self, @@ -17,10 +17,10 @@ class LeNet(nn.Module): kernel_sizes: Tuple[int, ...] = (3, 3, 2), hidden_size: Tuple[int, ...] = (9216, 128), dropout_rate: float = 0.2, - output_size: int = 10, + num_classes: int = 10, activation_fn: Optional[str] = "relu", ) -> None: - """The LeNet network. + """Initialization of the LeNet network. Args: channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64). @@ -28,7 +28,7 @@ class LeNet(nn.Module): hidden_size (Tuple[int, ...]): Size of the flattend output form the convolutional layers. Defaults to (9216, 128). dropout_rate (float): The dropout rate. Defaults to 0.2. - output_size (int): Number of classes. Defaults to 10. + num_classes (int): Number of classes. Defaults to 10. activation_fn (Optional[str]): The name of non-linear activation function. Defaults to relu. """ @@ -55,7 +55,7 @@ class LeNet(nn.Module): nn.Linear(in_features=hidden_size[0], out_features=hidden_size[1]), activation_fn, nn.Dropout(p=dropout_rate), - nn.Linear(in_features=hidden_size[1], out_features=output_size), + nn.Linear(in_features=hidden_size[1], out_features=num_classes), ] self.layers = nn.Sequential(*self.layers) diff --git a/src/text_recognizer/networks/line_lstm_ctc.py b/src/text_recognizer/networks/line_lstm_ctc.py index 2e2c3a5..988b615 100644 --- a/src/text_recognizer/networks/line_lstm_ctc.py +++ b/src/text_recognizer/networks/line_lstm_ctc.py @@ -1,5 +1,81 @@ """LSTM with CTC for handwritten text recognition within a line.""" +import importlib +from typing import Callable, Dict, List, Optional, Tuple, Type, Union +from einops import rearrange, reduce +from einops.layers.torch import Rearrange, Reduce import torch from torch import nn from torch import Tensor + + +class LineRecurrentNetwork(nn.Module): + """Network that takes a image of a text line and predicts tokens that are in the image.""" + + def __init__( + self, + encoder: str, + encoder_args: Dict = None, + flatten: bool = True, + input_size: int = 128, + hidden_size: int = 128, + num_layers: int = 1, + num_classes: int = 80, + patch_size: Tuple[int, int] = (28, 28), + stride: Tuple[int, int] = (1, 14), + ) -> None: + super().__init__() + self.encoder_args = encoder_args or {} + self.patch_size = patch_size + self.stride = stride + self.sliding_window = self._configure_sliding_window() + self.input_size = input_size + self.hidden_size = hidden_size + self.encoder = self._configure_encoder(encoder) + self.flatten = flatten + self.rnn = nn.LSTM( + input_size=self.input_size, + hidden_size=self.hidden_size, + num_layers=num_layers, + ) + self.decoder = nn.Sequential( + nn.Linear(in_features=self.hidden_size, out_features=num_classes), + nn.LogSoftmax(dim=2), + ) + + def _configure_encoder(self, encoder: str) -> Type[nn.Module]: + network_module = importlib.import_module("text_recognizer.networks") + encoder_ = getattr(network_module, encoder) + return encoder_(**self.encoder_args) + + def _configure_sliding_window(self) -> nn.Sequential: + return nn.Sequential( + nn.Unfold(kernel_size=self.patch_size, stride=self.stride), + Rearrange( + "b (c h w) t -> b t c h w", + h=self.patch_size[0], + w=self.patch_size[1], + c=1, + ), + ) + + def forward(self, x: Tensor) -> Tensor: + """Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM.""" + if len(x.shape) == 3: + x = x.unsqueeze(0) + x = self.sliding_window(x) + + # Rearrange from a sequence of patches for feedforward network. + b, t = x.shape[:2] + x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t) + x = self.encoder(x) + + # Avgerage pooling. + x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t) if self.flatten else x + + # Sequence predictions. + x, _ = self.rnn(x) + + # Sequence to classifcation layer. + x = self.decoder(x) + return x diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/misc.py index 6f61b5d..cac9e78 100644 --- a/src/text_recognizer/networks/misc.py +++ b/src/text_recognizer/networks/misc.py @@ -22,9 +22,10 @@ def sliding_window( """ unfold = nn.Unfold(kernel_size=patch_size, stride=stride) # Preform the slidning window, unsqueeze as the channel dimesion is lost. - patches = unfold(images).unsqueeze(1) + c = images.shape[1] + patches = unfold(images) patches = rearrange( - patches, "b c (h w) t -> b t c h w", h=patch_size[0], w=patch_size[1] + patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1] ) return patches diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py index acebdaa..d66af28 100644 --- a/src/text_recognizer/networks/mlp.py +++ b/src/text_recognizer/networks/mlp.py @@ -14,7 +14,7 @@ class MLP(nn.Module): def __init__( self, input_size: int = 784, - output_size: int = 10, + num_classes: int = 10, hidden_size: Union[int, List] = 128, num_layers: int = 3, dropout_rate: float = 0.2, @@ -24,7 +24,7 @@ class MLP(nn.Module): Args: input_size (int): The input shape of the network. Defaults to 784. - output_size (int): Number of classes in the dataset. Defaults to 10. + num_classes (int): Number of classes in the dataset. Defaults to 10. hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128. num_layers (int): The number of hidden layers. Defaults to 3. dropout_rate (float): The dropout rate at each layer. Defaults to 0.2. @@ -55,7 +55,7 @@ class MLP(nn.Module): self.layers.append(nn.Dropout(p=dropout_rate)) self.layers.append( - nn.Linear(in_features=hidden_size[-1], out_features=output_size) + nn.Linear(in_features=hidden_size[-1], out_features=num_classes) ) self.layers = nn.Sequential(*self.layers) diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py index 47e351a..1b5d6b3 100644 --- a/src/text_recognizer/networks/residual_network.py +++ b/src/text_recognizer/networks/residual_network.py @@ -8,6 +8,7 @@ from torch import nn from torch import Tensor from text_recognizer.networks.misc import activation_function +from text_recognizer.networks.stn import SpatialTransformerNetwork class Conv2dAuto(nn.Conv2d): @@ -197,25 +198,28 @@ class ResidualLayer(nn.Module): return x -class Encoder(nn.Module): +class ResidualNetworkEncoder(nn.Module): """Encoder network.""" def __init__( self, in_channels: int = 1, - block_sizes: List[int] = (32, 64), - depths: List[int] = (2, 2), + block_sizes: Union[int, List[int]] = (32, 64), + depths: Union[int, List[int]] = (2, 2), activation: str = "relu", block: Type[nn.Module] = BasicBlock, + levels: int = 1, + stn: bool = False, *args, **kwargs ) -> None: super().__init__() - - self.block_sizes = block_sizes - self.depths = depths + self.stn = SpatialTransformerNetwork() if stn else None + self.block_sizes = ( + block_sizes if isinstance(block_sizes, list) else [block_sizes] * levels + ) + self.depths = depths if isinstance(depths, list) else [depths] * levels self.activation = activation - self.gate = nn.Sequential( nn.Conv2d( in_channels=in_channels, @@ -227,7 +231,7 @@ class Encoder(nn.Module): ), nn.BatchNorm2d(self.block_sizes[0]), activation_function(self.activation), - nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + nn.MaxPool2d(kernel_size=2, stride=2, padding=1), ) self.blocks = self._configure_blocks(block) @@ -271,11 +275,13 @@ class Encoder(nn.Module): # If batch dimenstion is missing, it needs to be added. if len(x.shape) == 3: x = x.unsqueeze(0) + if self.stn is not None: + x = self.stn(x) x = self.gate(x) return self.blocks(x) -class Decoder(nn.Module): +class ResidualNetworkDecoder(nn.Module): """Classification head.""" def __init__(self, in_features: int, num_classes: int = 80) -> None: @@ -295,19 +301,12 @@ class ResidualNetwork(nn.Module): def __init__(self, in_channels: int, num_classes: int, *args, **kwargs) -> None: super().__init__() - self.encoder = Encoder(in_channels, *args, **kwargs) - self.decoder = Decoder( + self.encoder = ResidualNetworkEncoder(in_channels, *args, **kwargs) + self.decoder = ResidualNetworkDecoder( in_features=self.encoder.blocks[-1].blocks[-1].expanded_channels, num_classes=num_classes, ) - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") - elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - def forward(self, x: Tensor) -> Tensor: """Forward pass.""" x = self.encoder(x) diff --git a/src/text_recognizer/networks/stn.py b/src/text_recognizer/networks/stn.py new file mode 100644 index 0000000..b031128 --- /dev/null +++ b/src/text_recognizer/networks/stn.py @@ -0,0 +1,44 @@ +"""Spatial Transformer Network.""" + +from einops.layers.torch import Rearrange +import torch +from torch import nn +from torch import Tensor +import torch.nn.functional as F + + +class SpatialTransformerNetwork(nn.Module): + """A network with differentiable attention. + + Network that learns how to perform spatial transformations on the input image in order to enhance the + geometric invariance of the model. + + # TODO: add arguements to make it more general. + + """ + + def __init__(self) -> None: + super().__init__() + # Initialize the identity transformation and its weights and biases. + linear = nn.Linear(32, 3 * 2) + linear.weight.data.zero_() + linear.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)) + + self.theta = nn.Sequential( + nn.Conv2d(in_channels=1, out_channels=8, kernel_size=7), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.ReLU(inplace=True), + nn.Conv2d(in_channels=8, out_channels=10, kernel_size=5), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.ReLU(inplace=True), + Rearrange("b c h w -> b (c h w)", h=3, w=3), + nn.Linear(in_features=10 * 3 * 3, out_features=32), + nn.ReLU(inplace=True), + linear, + Rearrange("b (row col) -> b row col", row=2, col=3), + ) + + def forward(self, x: Tensor) -> Tensor: + """The spatial transformation.""" + grid = F.affine_grid(self.theta(x), x.shape) + return F.grid_sample(x, grid, align_corners=False) diff --git a/src/text_recognizer/networks/wide_resnet.py b/src/text_recognizer/networks/wide_resnet.py new file mode 100644 index 0000000..d1c8f9a --- /dev/null +++ b/src/text_recognizer/networks/wide_resnet.py @@ -0,0 +1,214 @@ +"""Wide Residual CNN.""" +from functools import partial +from typing import Callable, Dict, List, Optional, Type, Union + +from einops.layers.torch import Rearrange, Reduce +import numpy as np +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.misc import activation_function + + +def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: + """Helper function for a 3x3 2d convolution.""" + return nn.Conv2d( + in_channels=in_planes, + out_channels=out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False, + ) + + +def conv_init(module: Type[nn.Module]) -> None: + """Initializes the weights for convolution and batchnorms.""" + classname = module.__class__.__name__ + if classname.find("Conv") != -1: + nn.init.xavier_uniform_(module.weight, gain=np.sqrt(2)) + nn.init.constant(module.bias, 0) + elif classname.find("BatchNorm") != -1: + nn.init.constant(module.weight, 1) + nn.init.constant(module.bias, 0) + + +class WideBlock(nn.Module): + """Block used in WideResNet.""" + + def __init__( + self, + in_planes: int, + out_planes: int, + dropout_rate: float, + stride: int = 1, + activation: str = "relu", + ) -> None: + super().__init__() + self.in_planes = in_planes + self.out_planes = out_planes + self.dropout_rate = dropout_rate + self.stride = stride + self.activation = activation_function(activation) + + # Build blocks. + self.blocks = nn.Sequential( + nn.BatchNorm2d(self.in_planes), + self.activation, + conv3x3(in_planes=self.in_planes, out_planes=self.out_planes), + nn.Dropout(p=self.dropout_rate), + nn.BatchNorm2d(self.out_planes), + self.activation, + conv3x3( + in_planes=self.out_planes, + out_planes=self.out_planes, + stride=self.stride, + ), + ) + + self.shortcut = ( + nn.Sequential( + nn.Conv2d( + in_channels=self.in_planes, + out_channels=self.out_planes, + kernel_size=1, + stride=self.stride, + bias=False, + ), + ) + if self._apply_shortcut + else None + ) + + @property + def _apply_shortcut(self) -> bool: + """If shortcut should be applied or not.""" + return self.stride != 1 or self.in_planes != self.out_planes + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + residual = x + if self._apply_shortcut: + residual = self.shortcut(x) + x = self.blocks(x) + x += residual + return x + + +class WideResidualNetwork(nn.Module): + """WideResNet for character predictions. + + Can be used for classification or encoding of images to a latent vector. + + """ + + def __init__( + self, + in_channels: int = 1, + in_planes: int = 16, + num_classes: int = 80, + depth: int = 16, + width_factor: int = 10, + dropout_rate: float = 0.0, + num_layers: int = 3, + block: Type[nn.Module] = WideBlock, + activation: str = "relu", + use_decoder: bool = True, + ) -> None: + """The initialization of the WideResNet. + + Args: + in_channels (int): Number of input channels. Defaults to 1. + in_planes (int): Number of channels to use in the first output kernel. Defaults to 16. + num_classes (int): Number of classes. Defaults to 80. + depth (int): Set the number of blocks to use. Defaults to 16. + width_factor (int): Factor for scaling the number of channels in the network. Defaults to 10. + dropout_rate (float): The dropout rate. Defaults to 0.0. + num_layers (int): Number of layers of blocks. Defaults to 3. + block (Type[nn.Module]): The default block is WideBlock. Defaults to WideBlock. + activation (str): Name of the activation to use. Defaults to "relu". + use_decoder (bool): If True, the network output character predictions, if False, the network outputs a + latent vector. Defaults to True. + + Raises: + RuntimeError: If the depth is not of the size `6n+4`. + + """ + + super().__init__() + if (depth - 4) % 6 != 0: + raise RuntimeError("Wide-resnet depth should be 6n+4") + self.in_channels = in_channels + self.in_planes = in_planes + self.num_classes = num_classes + self.num_blocks = (depth - 4) // 6 + self.width_factor = width_factor + self.num_layers = num_layers + self.block = block + self.dropout_rate = dropout_rate + self.activation = activation_function(activation) + + self.num_stages = [self.in_planes] + [ + self.in_planes * 2 ** n * self.width_factor for n in range(self.num_layers) + ] + self.num_stages = list(zip(self.num_stages, self.num_stages[1:])) + self.strides = [1] + [2] * (self.num_layers - 1) + + self.encoder = nn.Sequential( + conv3x3(in_planes=self.in_channels, out_planes=self.in_planes), + *[ + self._configure_wide_layer( + in_planes=in_planes, + out_planes=out_planes, + stride=stride, + activation=activation, + ) + for (in_planes, out_planes), stride in zip( + self.num_stages, self.strides + ) + ], + ) + + self.decoder = ( + nn.Sequential( + nn.BatchNorm2d(self.num_stages[-1][-1], momentum=0.8), + self.activation, + Reduce("b c h w -> b c", "mean"), + nn.Linear( + in_features=self.num_stages[-1][-1], out_features=self.num_classes + ), + ) + if use_decoder + else None + ) + + self.apply(conv_init) + + def _configure_wide_layer( + self, in_planes: int, out_planes: int, stride: int, activation: str + ) -> List: + strides = [stride] + [1] * (self.num_blocks - 1) + planes = [out_planes] * len(strides) + planes = [(in_planes, out_planes)] + list(zip(planes, planes[1:])) + return nn.Sequential( + *[ + self.block( + in_planes=in_planes, + out_planes=out_planes, + dropout_rate=self.dropout_rate, + stride=stride, + activation=activation, + ) + for (in_planes, out_planes), stride in zip(planes, strides) + ] + ) + + def forward(self, x: Tensor) -> Tensor: + """Feedforward pass.""" + if len(x.shape) == 3: + x = x.unsqueeze(0) + x = self.encoder(x) + if self.decoder is not None: + x = self.decoder(x) + return x diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt Binary files differindex 86cf103..32c83cc 100644 --- a/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt +++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt Binary files differindex a5c6aaf..a25bcd1 100644 --- a/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt +++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_SpinalVGG_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_SpinalVGG_weights.pt Binary files differnew file mode 100644 index 0000000..e720299 --- /dev/null +++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_SpinalVGG_weights.pt diff --git a/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt b/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt Binary files differnew file mode 100644 index 0000000..9aec6ae --- /dev/null +++ b/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt |