From 8248f173132dfb7e47ec62b08e9235990c8626e3 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Wed, 24 Mar 2021 22:15:54 +0100 Subject: renamed datasets to data, added iam refactor --- text_recognizer/datasets/base_dataset.py | 73 -------------------------------- 1 file changed, 73 deletions(-) delete mode 100644 text_recognizer/datasets/base_dataset.py (limited to 'text_recognizer/datasets/base_dataset.py') diff --git a/text_recognizer/datasets/base_dataset.py b/text_recognizer/datasets/base_dataset.py deleted file mode 100644 index a9e9c24..0000000 --- a/text_recognizer/datasets/base_dataset.py +++ /dev/null @@ -1,73 +0,0 @@ -"""Base PyTorch Dataset class.""" -from typing import Any, Callable, Dict, Sequence, Tuple, Union - -import torch -from torch import Tensor -from torch.utils.data import Dataset - - -class BaseDataset(Dataset): - """ - Base Dataset class that processes data and targets through optional transfroms. - - Args: - data (Union[Sequence, Tensor]): Torch tensors, numpy arrays, or PIL images. - targets (Union[Sequence, Tensor]): Torch tensors or numpy arrays. - tranform (Callable): Function that takes a datum and applies transforms. - target_transform (Callable): Fucntion that takes a target and applies - target transforms. - """ - - def __init__( - self, - data: Union[Sequence, Tensor], - targets: Union[Sequence, Tensor], - transform: Callable = None, - target_transform: Callable = None, - ) -> None: - if len(data) != len(targets): - raise ValueError("Data and targets must be of equal length.") - self.data = data - self.targets = targets - self.transform = transform - self.target_transform = target_transform - - def __len__(self) -> int: - """Return the length of the dataset.""" - return len(self.data) - - def __getitem__(self, index: int) -> Tuple[Any, Any]: - """Return a datum and its target, after processing by transforms. - - Args: - index (int): Index of a datum in the dataset. - - Returns: - Tuple[Any, Any]: Datum and target pair. - - """ - datum, target = self.data[index], self.targets[index] - - if self.transform is not None: - datum = self.transform(datum) - - if self.target_transform is not None: - target = self.target_transform(target) - - return datum, target - - -def convert_strings_to_labels( - strings: Sequence[str], mapping: Dict[str, int], length: int -) -> Tensor: - """ - Convert a sequence of N strings to (N, length) ndarray, with each string wrapped with and tokens, - and padded wiht the

token. - """ - labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["

"] - for i, string in enumerate(strings): - tokens = list(string) - tokens = ["", *tokens, ""] - for j, token in enumerate(tokens): - labels[i, j] = mapping[token] - return labels -- cgit v1.2.3-70-g09d2