summaryrefslogtreecommitdiff
path: root/text_recognizer/data
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-22 08:15:58 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-22 08:15:58 +0200
commit1ca8b0b9e0613c1e02f6a5d8b49e20c4d6916412 (patch)
tree5e610ac459c9b254f8826e92372346f01f8e2412 /text_recognizer/data
parentffa4be4bf4e3758e01d52a9c1f354a05a90b93de (diff)
Fixed training script, able to train vqvae
Diffstat (limited to 'text_recognizer/data')
-rw-r--r--text_recognizer/data/__init__.py3
-rw-r--r--text_recognizer/data/emnist_lines.py2
-rw-r--r--text_recognizer/data/iam_extended_paragraphs.py15
-rw-r--r--text_recognizer/data/iam_paragraphs.py23
-rw-r--r--text_recognizer/data/iam_preprocessor.py1
-rw-r--r--text_recognizer/data/iam_synthetic_paragraphs.py7
-rw-r--r--text_recognizer/data/mappings.py16
-rw-r--r--text_recognizer/data/transforms.py14
8 files changed, 54 insertions, 27 deletions
diff --git a/text_recognizer/data/__init__.py b/text_recognizer/data/__init__.py
index 9a42fa9..3599a8b 100644
--- a/text_recognizer/data/__init__.py
+++ b/text_recognizer/data/__init__.py
@@ -2,3 +2,6 @@
from .base_dataset import BaseDataset, convert_strings_to_labels, split_dataset
from .base_data_module import BaseDataModule, load_and_print_info
from .download_utils import download_dataset
+from .iam_paragraphs import IAMParagraphs
+from .iam_synthetic_paragraphs import IAMSyntheticParagraphs
+from .iam_extended_paragraphs import IAMExtendedParagraphs
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 72665d0..9650198 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -57,8 +57,8 @@ class EMNISTLines(BaseDataModule):
self.num_test = num_test
self.emnist = EMNIST()
- # TODO: fix mapping
self.mapping = self.emnist.mapping
+
max_width = (
int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap))
+ IMAGE_X_PADDING
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index d2529b4..2380660 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -10,18 +10,27 @@ from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs
class IAMExtendedParagraphs(BaseDataModule):
def __init__(
self,
- batch_size: int = 128,
+ batch_size: int = 16,
num_workers: int = 0,
train_fraction: float = 0.8,
augment: bool = True,
+ word_pieces: bool = False,
) -> None:
super().__init__(batch_size, num_workers)
self.iam_paragraphs = IAMParagraphs(
- batch_size, num_workers, train_fraction, augment,
+ batch_size,
+ num_workers,
+ train_fraction,
+ augment,
+ word_pieces,
)
self.iam_synthetic_paragraphs = IAMSyntheticParagraphs(
- batch_size, num_workers, train_fraction, augment,
+ batch_size,
+ num_workers,
+ train_fraction,
+ augment,
+ word_pieces,
)
self.dims = self.iam_paragraphs.dims
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index f588587..62c44f9 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -5,8 +5,7 @@ from typing import Dict, List, Optional, Sequence, Tuple
from loguru import logger
import numpy as np
-from PIL import Image, ImageFile, ImageOps
-import torch
+from PIL import Image, ImageOps
import torchvision.transforms as transforms
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm
@@ -19,6 +18,7 @@ from text_recognizer.data.base_dataset import (
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.emnist import emnist_mapping
from text_recognizer.data.iam import IAM
+from text_recognizer.data.transforms import WordPiece
PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_paragraphs"
@@ -37,15 +37,15 @@ class IAMParagraphs(BaseDataModule):
def __init__(
self,
- batch_size: int = 128,
+ batch_size: int = 16,
num_workers: int = 0,
train_fraction: float = 0.8,
augment: bool = True,
+ word_pieces: bool = False,
) -> None:
super().__init__(batch_size, num_workers)
- # TODO: pass in transform and target transform
- # TODO: pass in mapping
self.augment = augment
+ self.word_pieces = word_pieces
self.mapping, self.inverse_mapping, _ = emnist_mapping(
extra_symbols=[NEW_LINE_TOKEN]
)
@@ -101,6 +101,7 @@ class IAMParagraphs(BaseDataModule):
data,
targets,
transform=get_transform(image_shape=self.dims[1:], augment=augment),
+ target_transform=get_target_transform(self.word_pieces)
)
logger.info(f"Loading IAM paragraph regions and lines for {stage}...")
@@ -161,7 +162,10 @@ def get_dataset_properties() -> Dict:
"min": min(_get_property_values("num_lines")),
"max": max(_get_property_values("num_lines")),
},
- "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0),},
+ "crop_shape": {
+ "min": crop_shapes.min(axis=0),
+ "max": crop_shapes.max(axis=0),
+ },
"aspect_ratio": {
"min": aspect_ratio.min(axis=0),
"max": aspect_ratio.max(axis=0),
@@ -282,7 +286,9 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Com
),
transforms.ColorJitter(brightness=(0.8, 1.6)),
transforms.RandomAffine(
- degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR,
+ degrees=1,
+ shear=(-10, 10),
+ interpolation=InterpolationMode.BILINEAR,
),
]
else:
@@ -290,6 +296,9 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Com
transforms_list.append(transforms.ToTensor())
return transforms.Compose(transforms_list)
+def get_target_transform(word_pieces: bool) -> Optional[transforms.Compose]:
+ """Transform emnist characters to word pieces."""
+ return transforms.Compose([WordPiece()]) if word_pieces else None
def _labels_filename(split: str) -> Path:
"""Return filename of processed labels."""
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py
index 60f8a9f..b5f72da 100644
--- a/text_recognizer/data/iam_preprocessor.py
+++ b/text_recognizer/data/iam_preprocessor.py
@@ -89,6 +89,7 @@ class Preprocessor:
self.lexicon = None
if self.special_tokens is not None:
+ self.special_tokens += ("#", "*")
self.tokens += self.special_tokens
self.graphemes += self.special_tokens
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index 9f1bd12..4ccc5c2 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -18,6 +18,7 @@ from text_recognizer.data.base_data_module import BaseDataModule, load_and_print
from text_recognizer.data.iam_paragraphs import (
get_dataset_properties,
get_transform,
+ get_target_transform,
NEW_LINE_TOKEN,
IAMParagraphs,
IMAGE_SCALE_FACTOR,
@@ -41,12 +42,13 @@ class IAMSyntheticParagraphs(IAMParagraphs):
def __init__(
self,
- batch_size: int = 128,
+ batch_size: int = 16,
num_workers: int = 0,
train_fraction: float = 0.8,
augment: bool = True,
+ word_pieces: bool = False,
) -> None:
- super().__init__(batch_size, num_workers, train_fraction, augment)
+ super().__init__(batch_size, num_workers, train_fraction, augment, word_pieces)
def prepare_data(self) -> None:
"""Prepare IAM lines to be used to generate paragraphs."""
@@ -95,6 +97,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
transform=get_transform(
image_shape=self.dims[1:], augment=self.augment
),
+ target_transform=get_target_transform(self.word_pieces)
)
def __repr__(self) -> str:
diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py
index cfa0ec7..f4016ba 100644
--- a/text_recognizer/data/mappings.py
+++ b/text_recognizer/data/mappings.py
@@ -8,7 +8,7 @@ import torch
from torch import Tensor
from text_recognizer.data.emnist import emnist_mapping
-from text_recognizer.datasets.iam_preprocessor import Preprocessor
+from text_recognizer.data.iam_preprocessor import Preprocessor
class AbstractMapping(ABC):
@@ -57,14 +57,14 @@ class EmnistMapping(AbstractMapping):
class WordPieceMapping(EmnistMapping):
def __init__(
self,
- num_features: int,
- tokens: str,
- lexicon: str,
+ num_features: int = 1000,
+ tokens: str = "iamdb_1kwp_tokens_1000.txt" ,
+ lexicon: str = "iamdb_1kwp_lex_1000.txt",
data_dir: Optional[Union[str, Path]] = None,
use_words: bool = False,
prepend_wordsep: bool = False,
special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
- extra_symbols: Optional[Sequence[str]] = None,
+ extra_symbols: Optional[Sequence[str]] = ("\n", ),
) -> None:
super().__init__(extra_symbols)
self.wordpiece_processor = self._configure_wordpiece_processor(
@@ -78,8 +78,8 @@ class WordPieceMapping(EmnistMapping):
extra_symbols,
)
+ @staticmethod
def _configure_wordpiece_processor(
- self,
num_features: int,
tokens: str,
lexicon: str,
@@ -90,7 +90,7 @@ class WordPieceMapping(EmnistMapping):
extra_symbols: Optional[Sequence[str]],
) -> Preprocessor:
data_dir = (
- (Path(__file__).resolve().parents[2] / "data" / "raw" / "iam" / "iamdb")
+ (Path(__file__).resolve().parents[2] / "data" / "downloaded" / "iam" / "iamdb")
if data_dir is None
else Path(data_dir)
)
@@ -138,6 +138,6 @@ class WordPieceMapping(EmnistMapping):
return self.wordpiece_processor.to_index(text)
def emnist_to_wordpiece_indices(self, x: Tensor) -> Tensor:
- text = self.mapping.get_text(x)
+ text = "".join([self.mapping[i] for i in x])
text = text.lower().replace(" ", "▁")
return torch.LongTensor(self.wordpiece_processor.to_index(text))
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
index f53df64..8d1bedd 100644
--- a/text_recognizer/data/transforms.py
+++ b/text_recognizer/data/transforms.py
@@ -4,7 +4,7 @@ from typing import Optional, Union, Sequence
from torch import Tensor
-from text_recognizer.datasets.mappings import WordPieceMapping
+from text_recognizer.data.mappings import WordPieceMapping
class WordPiece:
@@ -12,14 +12,15 @@ class WordPiece:
def __init__(
self,
- num_features: int,
- tokens: str,
- lexicon: str,
+ num_features: int = 1000,
+ tokens: str = "iamdb_1kwp_tokens_1000.txt" ,
+ lexicon: str = "iamdb_1kwp_lex_1000.txt",
data_dir: Optional[Union[str, Path]] = None,
use_words: bool = False,
prepend_wordsep: bool = False,
special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
- extra_symbols: Optional[Sequence[str]] = None,
+ extra_symbols: Optional[Sequence[str]] = ("\n",),
+ max_len: int = 192,
) -> None:
self.mapping = WordPieceMapping(
num_features,
@@ -31,6 +32,7 @@ class WordPiece:
special_tokens,
extra_symbols,
)
+ self.max_len = max_len
def __call__(self, x: Tensor) -> Tensor:
- return self.mapping.emnist_to_wordpiece_indices(x)
+ return self.mapping.emnist_to_wordpiece_indices(x)[:self.max_len]