Diffstat (limited to 'text_recognizer')
-rw-r--r--  text_recognizer/data/base_data_module.py | 11
-rw-r--r--  text_recognizer/data/emnist.py | 10
-rw-r--r--  text_recognizer/data/emnist_essentials.json (renamed from text_recognizer/data/mappings/emnist_essentials.json) | 2
-rw-r--r--  text_recognizer/data/emnist_lines.py | 16
-rw-r--r--  text_recognizer/data/iam.py | 6
-rw-r--r--  text_recognizer/data/iam_extended_paragraphs.py | 12
-rw-r--r--  text_recognizer/data/iam_lines.py | 15
-rw-r--r--  text_recognizer/data/iam_paragraphs.py | 12
-rw-r--r--  text_recognizer/data/iam_synthetic_paragraphs.py | 12
-rw-r--r--  text_recognizer/data/mappings/__init__.py | 2
-rw-r--r--  text_recognizer/data/tokenizer.py (renamed from text_recognizer/data/mappings/emnist.py) | 32
-rw-r--r--  text_recognizer/metadata/shared.py | 3
-rw-r--r--  text_recognizer/models/base.py | 9
-rw-r--r--  text_recognizer/models/transformer.py | 66
14 files changed, 108 insertions, 100 deletions
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index 7863333..bd6fd99 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -6,7 +6,7 @@ from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader
 
 from text_recognizer.data.base_dataset import BaseDataset
-from text_recognizer.data.mappings import EmnistMapping
+from text_recognizer.data.tokenizer import Tokenizer
 
 T = TypeVar("T")
@@ -24,7 +24,7 @@ class BaseDataModule(LightningDataModule):
     def __init__(
         self,
-        mapping: EmnistMapping,
+        tokenizer: Tokenizer,
         transform: Optional[Callable] = None,
         test_transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
@@ -34,7 +34,7 @@
         pin_memory: bool = True,
     ) -> None:
         super().__init__()
-        self.mapping = mapping
+        self.tokenizer = tokenizer
         self.transform = transform
         self.test_transform = test_transform
         self.target_transform = target_transform
@@ -50,11 +50,6 @@ class BaseDataModule(LightningDataModule):
         self.dims: Tuple[int, ...] = None
         self.output_dims: Tuple[int, ...] = None
 
-    @classmethod
-    def data_dirname(cls: T) -> Path:
-        """Return the path to the base data directory."""
-        return Path(__file__).resolve().parents[2] / "data"
-
     def config(self) -> Dict:
         """Return important settings of the dataset."""
         return {
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index 143705e..b5db075 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -13,7 +13,6 @@ from loguru import logger as log
 
 from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
 from text_recognizer.data.base_dataset import BaseDataset, split_dataset
-from text_recognizer.data.transforms.load_transform import load_transform_from_file
 from text_recognizer.data.utils.download_utils import download_dataset
 import text_recognizer.metadata.emnist as metadata
@@ -32,7 +31,7 @@ class EMNIST(BaseDataModule):
     def __init__(self) -> None:
         super().__init__()
-        self.dims = (1, *self.mapping.input_size)
+        self.dims = (1, *self.tokenizer.input_size)
 
     def prepare_data(self) -> None:
         """Downloads dataset if not present."""
@@ -65,8 +64,8 @@
         """Returns string with info about the dataset."""
         basic = (
             "EMNIST Dataset\n"
-            f"Num classes: {len(self.mapping)}\n"
-            f"Mapping: {self.mapping}\n"
+            f"Num classes: {len(self.tokenizer)}\n"
+            f"Mapping: {self.tokenizer}\n"
             f"Dims: {self.dims}\n"
         )
         if not any([self.data_train, self.data_val, self.data_test]):
@@ -193,5 +192,4 @@ def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]:
 def download_emnist() -> None:
     """Download dataset from internet, if it does not exists, and displays info."""
-    transform = load_transform_from_file("transform/default.yaml")
-    load_and_print_info(EMNIST(transform=transform, test_transfrom=transform))
+    load_and_print_info(EMNIST())
diff --git a/text_recognizer/data/mappings/emnist_essentials.json b/text_recognizer/data/emnist_essentials.json
index c412425..956c28d 100644
--- a/text_recognizer/data/mappings/emnist_essentials.json
+++ b/text_recognizer/data/emnist_essentials.json
@@ -1 +1 @@
-{"characters": ["<b>", "<s>", "<e>", "<p>", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]}
\ No newline at end of file
+{"characters": ["<b>", "<s>", "<e>", "<p>", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]}
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 88aac0d..8a31c44 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -12,7 +12,7 @@ from torch import Tensor
 
 from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
 from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels
 from text_recognizer.data.emnist import EMNIST
-from text_recognizer.data.mappings import EmnistMapping
+from text_recognizer.data.tokenizer import Tokenizer
 from text_recognizer.data.stems.line import LineStem
 from text_recognizer.data.utils.sentence_generator import SentenceGenerator
 import text_recognizer.metadata.emnist_lines as metadata
@@ -23,7 +23,7 @@ class EMNISTLines(BaseDataModule):
     def __init__(
         self,
-        mapping: EmnistMapping,
+        tokenizer: Tokenizer,
         transform: Optional[Callable] = None,
         test_transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
@@ -39,7 +39,7 @@
         num_test: int = 2_000,
     ) -> None:
         super().__init__(
-            mapping,
+            tokenizer,
             transform,
             test_transform,
             target_transform,
@@ -120,7 +120,7 @@
             "EMNISTLines2 Dataset\n"  # pylint: disable=no-member
             f"Min overlap: {self.min_overlap}\n"
             f"Max overlap: {self.max_overlap}\n"
-            f"Num classes: {len(self.mapping)}\n"
+            f"Num classes: {len(self.tokenizer)}\n"
             f"Dims: {self.dims}\n"
             f"Output dims: {self.output_dims}\n"
         )
@@ -153,17 +153,17 @@
         if split == "train":
             samples_by_char = _get_samples_by_char(
-                emnist.x_train, emnist.y_train, self.mapping.mapping
+                emnist.x_train, emnist.y_train, self.tokenizer.mapping
             )
             num = self.num_train
         elif split == "val":
             samples_by_char = _get_samples_by_char(
-                emnist.x_train, emnist.y_train, self.mapping.mapping
+                emnist.x_train, emnist.y_train, self.tokenizer.mapping
             )
             num = self.num_val
         else:
             samples_by_char = _get_samples_by_char(
-                emnist.x_test, emnist.y_test, self.mapping.mapping
+                emnist.x_test, emnist.y_test, self.tokenizer.mapping
             )
             num = self.num_test
@@ -178,7 +178,7 @@
                 self.dims,
             )
             y = convert_strings_to_labels(
-                y, self.mapping.inverse_mapping, length=metadata.MAX_OUTPUT_LENGTH
+                y, self.tokenizer.inverse_mapping, length=metadata.MAX_OUTPUT_LENGTH
             )
             f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf")
             f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf")
diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py
index 2ce1e9c..8a31205 100644
--- a/text_recognizer/data/iam.py
+++ b/text_recognizer/data/iam.py
@@ -4,14 +4,14 @@ Which encompasses both paragraphs and lines, with associated utilities.
""" import os -import xml.etree.ElementTree as ElementTree -import zipfile from pathlib import Path from typing import Any, Dict, List +import xml.etree.ElementTree as ElementTree +import zipfile -import toml from boltons.cacheutils import cachedproperty from loguru import logger as log +import toml from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.utils.download_utils import download_dataset diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py index 658626c..c6628a8 100644 --- a/text_recognizer/data/iam_extended_paragraphs.py +++ b/text_recognizer/data/iam_extended_paragraphs.py @@ -7,7 +7,7 @@ from text_recognizer.data.base_data_module import BaseDataModule, load_and_print from text_recognizer.data.iam_paragraphs import IAMParagraphs from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs from text_recognizer.data.transforms.pad import Pad -from text_recognizer.data.mappings import EmnistMapping +from text_recognizer.data.tokenizer import Tokenizer from text_recognizer.data.stems.paragraph import ParagraphStem import text_recognizer.metadata.iam_paragraphs as metadata @@ -17,7 +17,7 @@ class IAMExtendedParagraphs(BaseDataModule): def __init__( self, - mapping: EmnistMapping, + tokenizer: Tokenizer, transform: Optional[Callable] = None, test_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, @@ -27,7 +27,7 @@ class IAMExtendedParagraphs(BaseDataModule): pin_memory: bool = True, ) -> None: super().__init__( - mapping, + tokenizer, transform, test_transform, target_transform, @@ -37,7 +37,7 @@ class IAMExtendedParagraphs(BaseDataModule): pin_memory, ) self.iam_paragraphs = IAMParagraphs( - mapping=self.mapping, + tokenizer=self.tokenizer, batch_size=self.batch_size, num_workers=self.num_workers, train_fraction=self.train_fraction, @@ -46,7 +46,7 @@ class IAMExtendedParagraphs(BaseDataModule): target_transform=self.target_transform, ) self.iam_synthetic_paragraphs = IAMSyntheticParagraphs( - mapping=self.mapping, + tokenizer=self.tokenizer, batch_size=self.batch_size, num_workers=self.num_workers, train_fraction=self.train_fraction, @@ -78,7 +78,7 @@ class IAMExtendedParagraphs(BaseDataModule): """Returns info about the dataset.""" basic = ( "IAM Original and Synthetic Paragraphs Dataset\n" # pylint: disable=no-member - f"Num classes: {len(self.mapping)}\n" + f"Num classes: {len(self.tokenizer)}\n" f"Dims: {self.dims}\n" f"Output dims: {self.output_dims}\n" ) diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py index e60d1ba..a0d9b59 100644 --- a/text_recognizer/data/iam_lines.py +++ b/text_recognizer/data/iam_lines.py @@ -19,8 +19,7 @@ from text_recognizer.data.base_dataset import ( split_dataset, ) from text_recognizer.data.iam import IAM -from text_recognizer.data.mappings import EmnistMapping -from text_recognizer.data.transforms.load_transform import load_transform_from_file +from text_recognizer.data.tokenizer import Tokenizer from text_recognizer.data.stems.line import IamLinesStem from text_recognizer.data.utils import image_utils import text_recognizer.metadata.iam_lines as metadata @@ -33,7 +32,7 @@ class IAMLines(BaseDataModule): def __init__( self, - mapping: EmnistMapping, + tokenizer: Tokenizer, transform: Optional[Callable] = None, test_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, @@ -43,7 +42,7 @@ class IAMLines(BaseDataModule): 
pin_memory: bool = True, ) -> None: super().__init__( - mapping, + tokenizer, transform, test_transform, target_transform, @@ -61,7 +60,7 @@ class IAMLines(BaseDataModule): return log.info("Cropping IAM lines regions...") - iam = IAM(mapping=EmnistMapping()) + iam = IAM(tokenizer=self.tokenizer) iam.prepare_data() crops_train, labels_train = line_crops_and_labels(iam, "train") crops_test, labels_test = line_crops_and_labels(iam, "test") @@ -100,7 +99,7 @@ class IAMLines(BaseDataModule): raise ValueError("Target length longer than max output length.") y_train = convert_strings_to_labels( - labels_train, self.mapping.inverse_mapping, length=self.output_dims[0] + labels_train, self.tokenizer.inverse_mapping, length=self.output_dims[0] ) data_train = BaseDataset( x_train, @@ -122,7 +121,7 @@ class IAMLines(BaseDataModule): raise ValueError("Taget length longer than max output length.") y_test = convert_strings_to_labels( - labels_test, self.mapping.inverse_mapping, length=self.output_dims[0] + labels_test, self.tokenizer.inverse_mapping, length=self.output_dims[0] ) self.data_test = BaseDataset( x_test, @@ -144,7 +143,7 @@ class IAMLines(BaseDataModule): """Return information about the dataset.""" basic = ( "IAM Lines dataset\n" - f"Num classes: {len(self.mapping)}\n" + f"Num classes: {len(self.tokenizer)}\n" f"Input dims: {self.dims}\n" f"Output dims: {self.output_dims}\n" ) diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index fe1f15c..a078c7d 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -17,7 +17,7 @@ from text_recognizer.data.base_dataset import ( ) from text_recognizer.data.iam import IAM from text_recognizer.data.transforms.pad import Pad -from text_recognizer.data.mappings import EmnistMapping +from text_recognizer.data.tokenizer import Tokenizer from text_recognizer.data.stems.paragraph import ParagraphStem import text_recognizer.metadata.iam_paragraphs as metadata @@ -27,7 +27,7 @@ class IAMParagraphs(BaseDataModule): def __init__( self, - mapping: EmnistMapping, + tokenizer: Tokenizer, transform: Optional[Callable] = None, test_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, @@ -37,7 +37,7 @@ class IAMParagraphs(BaseDataModule): pin_memory: bool = True, ) -> None: super().__init__( - mapping, + tokenizer, transform, test_transform, target_transform, @@ -56,7 +56,7 @@ class IAMParagraphs(BaseDataModule): log.info("Cropping IAM paragraph regions and saving them along with labels...") - iam = IAM(mapping=EmnistMapping(extra_symbols={metadata.NEW_LINE_TOKEN})) + iam = IAM(tokenizer=self.tokenizer) iam.prepare_data() properties = {} @@ -88,7 +88,7 @@ class IAMParagraphs(BaseDataModule): data = [resize_image(crop, metadata.IMAGE_SCALE_FACTOR) for crop in crops] targets = convert_strings_to_labels( strings=labels, - mapping=self.mapping.inverse_mapping, + mapping=self.tokenizer.inverse_mapping, length=self.output_dims[0], ) return BaseDataset( @@ -122,7 +122,7 @@ class IAMParagraphs(BaseDataModule): """Return information about the dataset.""" basic = ( "IAM Paragraphs Dataset\n" - f"Num classes: {len(self.mapping)}\n" + f"Num classes: {len(self.tokenizer)}\n" f"Input dims: {self.dims}\n" f"Output dims: {self.output_dims}\n" ) diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py index 91fda4a..511a8d4 100644 --- a/text_recognizer/data/iam_synthetic_paragraphs.py +++ 
b/text_recognizer/data/iam_synthetic_paragraphs.py @@ -19,7 +19,7 @@ from text_recognizer.data.iam_lines import ( load_line_crops_and_labels, save_images_and_labels, ) -from text_recognizer.data.mappings import EmnistMapping +from text_recognizer.data.tokenizer import Tokenizer from text_recognizer.data.stems.paragraph import ParagraphStem from text_recognizer.data.transforms.pad import Pad import text_recognizer.metadata.iam_synthetic_paragraphs as metadata @@ -30,7 +30,7 @@ class IAMSyntheticParagraphs(IAMParagraphs): def __init__( self, - mapping: EmnistMapping, + tokenizer: Tokenizer, transform: Optional[Callable] = None, test_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, @@ -40,7 +40,7 @@ class IAMSyntheticParagraphs(IAMParagraphs): pin_memory: bool = True, ) -> None: super().__init__( - mapping, + tokenizer, transform, test_transform, target_transform, @@ -58,7 +58,7 @@ class IAMSyntheticParagraphs(IAMParagraphs): log.info("Preparing IAM lines for synthetic paragraphs dataset.") log.info("Cropping IAM line regions and loading labels.") - iam = IAM(mapping=EmnistMapping(extra_symbols=(metadata.NEW_LINE_TOKEN,))) + iam = IAM(tokenizer=self.tokenizer) iam.prepare_data() crops_train, labels_train = line_crops_and_labels(iam, "train") @@ -94,7 +94,7 @@ class IAMSyntheticParagraphs(IAMParagraphs): targets = convert_strings_to_labels( strings=paragraphs_labels, - mapping=self.mapping.inverse_mapping, + mapping=self.tokenizer.inverse_mapping, length=self.output_dims[0], ) self.data_train = BaseDataset( @@ -108,7 +108,7 @@ class IAMSyntheticParagraphs(IAMParagraphs): """Return information about the dataset.""" basic = ( "IAM Synthetic Paragraphs Dataset\n" # pylint: disable=no-member - f"Num classes: {len(self.mapping)}\n" + f"Num classes: {len(self.tokenizer)}\n" f"Input dims : {self.dims}\n" f"Output dims: {self.output_dims}\n" ) diff --git a/text_recognizer/data/mappings/__init__.py b/text_recognizer/data/mappings/__init__.py deleted file mode 100644 index 635f506..0000000 --- a/text_recognizer/data/mappings/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""Mapping modules.""" -from text_recognizer.data.mappings.emnist import EmnistMapping diff --git a/text_recognizer/data/mappings/emnist.py b/text_recognizer/data/tokenizer.py index 331976e..a5f44e6 100644 --- a/text_recognizer/data/mappings/emnist.py +++ b/text_recognizer/data/tokenizer.py @@ -6,19 +6,29 @@ from typing import Dict, List, Optional, Sequence, Tuple, Union import torch from torch import Tensor -ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json" +import text_recognizer.metadata.shared as metadata -class EmnistMapping: +class Tokenizer: """Mapping for EMNIST labels.""" def __init__( self, extra_symbols: Optional[Sequence[str]] = None, lower: bool = True, + start_token: str = "<s>", + end_token: str = "<e>", + pad_token: str = "<p>", ) -> None: self.extra_symbols = set(extra_symbols) if extra_symbols is not None else None self.mapping, self.inverse_mapping, self.input_size = self._load_mapping() + self.start_token = start_token + self.end_token = end_token + self.pad_token = pad_token + self.start_index = int(self.get_value(self.start_token)) + self.end_index = int(self.get_value(self.end_token)) + self.pad_index = int(self.get_value(self.pad_token)) + self.ignore_indices = set([self.start_index, self.end_index, self.pad_index]) if lower: self._to_lower() @@ -31,7 +41,7 @@ class EmnistMapping: def _load_mapping(self) -> Tuple[List, Dict[str, int], List[int]]: 
"""Return the EMNIST mapping.""" - with ESSENTIALS_FILENAME.open() as f: + with metadata.ESSENTIALS_FILENAME.open() as f: essentials = json.load(f) mapping = list(essentials["characters"]) if self.extra_symbols is not None: @@ -57,19 +67,25 @@ class EmnistMapping: return self.mapping[index] raise KeyError(f"Index ({index}) not in mapping.") - def get_index(self, token: str) -> Tensor: + def get_value(self, token: str) -> Tensor: """Returns index value of token.""" if token in self.inverse_mapping: return torch.LongTensor([self.inverse_mapping[token]]) raise KeyError(f"Token ({token}) not found in inverse mapping.") - def get_text(self, indices: Union[List[int], Tensor]) -> str: + def decode(self, indices: Union[List[int], Tensor]) -> str: """Returns the text from a list of indices.""" if isinstance(indices, Tensor): indices = indices.tolist() - return "".join([self.mapping[index] for index in indices]) - - def get_indices(self, text: str) -> Tensor: + return "".join( + [ + self.mapping[index] + for index in indices + if index not in self.ignore_indices + ] + ) + + def encode(self, text: str) -> Tensor: """Returns tensor of indices for a string.""" return Tensor([self.inverse_mapping[token] for token in text]) diff --git a/text_recognizer/metadata/shared.py b/text_recognizer/metadata/shared.py index a4d1da0..cee5de4 100644 --- a/text_recognizer/metadata/shared.py +++ b/text_recognizer/metadata/shared.py @@ -1,4 +1,7 @@ from pathlib import Path +ESSENTIALS_FILENAME = ( + Path(__file__).parents[1].resolve() / "data" / "emnist_essentials.json" +) DATA_DIRNAME = Path(__file__).resolve().parents[2] / "data" DOWNLOADED_DATA_DIRNAME = DATA_DIRNAME / "downloded" diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index bb4e695..f8f4b40 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -9,7 +9,7 @@ from pytorch_lightning import LightningModule from torch import nn, Tensor from torchmetrics import Accuracy -from text_recognizer.data.mappings import EmnistMapping +from text_recognizer.data.tokenizer import Tokenizer class LitBase(LightningModule): @@ -21,8 +21,7 @@ class LitBase(LightningModule): loss_fn: Type[nn.Module], optimizer_config: DictConfig, lr_scheduler_config: Optional[DictConfig], - mapping: EmnistMapping, - ignore_index: Optional[int] = None, + tokenizer: Tokenizer, ) -> None: super().__init__() @@ -30,8 +29,8 @@ class LitBase(LightningModule): self.loss_fn = loss_fn self.optimizer_config = optimizer_config self.lr_scheduler_config = lr_scheduler_config - self.mapping = mapping - + self.tokenizer = tokenizer + ignore_index = int(self.tokenizer.get_value("<p>")) # Placeholders self.train_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index) self.val_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index) diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 2c74b7e..752f3eb 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -1,11 +1,12 @@ """Lightning model for base Transformers.""" +from collections.abc import Sequence from typing import Optional, Tuple, Type import torch from omegaconf import DictConfig from torch import nn, Tensor -from text_recognizer.data.mappings import EmnistMapping +from text_recognizer.data.tokenizer import Tokenizer from text_recognizer.models.base import LitBase from text_recognizer.models.metrics.cer import CharacterErrorRate from text_recognizer.models.metrics.wer import WordErrorRate @@ 
-19,33 +20,23 @@ class LitTransformer(LitBase): network: Type[nn.Module], loss_fn: Type[nn.Module], optimizer_config: DictConfig, - mapping: EmnistMapping, + tokenizer: Tokenizer, lr_scheduler_config: Optional[DictConfig] = None, max_output_len: int = 682, - start_token: str = "<s>", - end_token: str = "<e>", - pad_token: str = "<p>", ) -> None: - self.max_output_len = max_output_len - self.start_token = start_token - self.end_token = end_token - self.pad_token = pad_token - self.start_index = int(self.mapping.get_index(self.start_token)) - self.end_index = int(self.mapping.get_index(self.end_token)) - self.pad_index = int(self.mapping.get_index(self.pad_token)) - self.ignore_indices = set([self.start_index, self.end_index, self.pad_index]) - self.val_cer = CharacterErrorRate(self.ignore_indices) - self.test_cer = CharacterErrorRate(self.ignore_indices) - self.val_wer = WordErrorRate(self.ignore_indices) - self.test_wer = WordErrorRate(self.ignore_indices) super().__init__( network, loss_fn, optimizer_config, lr_scheduler_config, - mapping, - self.pad_index, + tokenizer, ) + self.max_output_len = max_output_len + self.ignore_indices = set([self.start_index, self.end_index, self.pad_index]) + self.val_cer = CharacterErrorRate(self.ignore_indices) + self.test_cer = CharacterErrorRate(self.ignore_indices) + self.val_wer = WordErrorRate(self.ignore_indices) + self.test_wer = WordErrorRate(self.ignore_indices) def forward(self, data: Tensor) -> Tensor: """Forward pass with the transformer network.""" @@ -63,11 +54,12 @@ class LitTransformer(LitBase): """Validation step.""" data, targets = batch preds = self.predict(data) - self.val_acc(preds, targets) + pred_text, target_text = self.get_text(preds, targets) + self.val_acc(pred_text, target_text) self.log("val/acc", self.val_acc, on_step=False, on_epoch=True) - self.val_cer(preds, targets) + self.val_cer(pred_text, target_text) self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True) - self.val_wer(preds, targets) + self.val_wer(pred_text, target_text) self.log("val/wer", self.val_wer, on_step=False, on_epoch=True, prog_bar=True) def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: @@ -75,14 +67,22 @@ class LitTransformer(LitBase): data, targets = batch # Compute the text prediction. - pred = self(data) - self.test_acc(pred, targets) + preds = self(data) + pred_text, target_text = self.get_text(preds, targets) + self.test_acc(pred_text, target_text) self.log("test/acc", self.test_acc, on_step=False, on_epoch=True) - self.test_cer(pred, targets) + self.test_cer(pred_text, target_text) self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True) - self.test_wer(pred, targets) + self.test_wer(pred_text, target_text) self.log("test/wer", self.test_wer, on_step=False, on_epoch=True, prog_bar=True) + def get_text( + self, preds: Tensor, targets: Tensor + ) -> Tuple[Sequence[str], Sequence[str]]: + pred_text = [self.tokenizer.decode(p) for p in preds] + target_text = [self.tokenizer.decode(t) for t in targets] + return pred_text, target_text + @torch.no_grad() def predict(self, x: Tensor) -> Tensor: """Predicts text in image. @@ -97,6 +97,9 @@ class LitTransformer(LitBase): Returns: Tensor: A tensor of token indices of the predictions from the model. """ + start_index = self.tokenizer.start_index + end_index = self.tokenizer.start_index + pad_index = self.tokenizer.start_index bsz = x.shape[0] # Encode image(s) to latent vectors. 
@@ -104,7 +107,7 @@ class LitTransformer(LitBase): # Create a placeholder matrix for storing outputs from the network output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device) - output[:, 0] = self.start_index + output[:, 0] = start_index for Sy in range(1, self.max_output_len): context = output[:, :Sy] # (B, Sy) @@ -114,16 +117,13 @@ class LitTransformer(LitBase): # Early stopping of prediction loop if token is end or padding token. if ( - (output[:, Sy - 1] == self.end_index) - | (output[:, Sy - 1] == self.pad_index) + (output[:, Sy - 1] == end_index) | (output[:, Sy - 1] == pad_index) ).all(): break # Set all tokens after end token to pad token. for Sy in range(1, self.max_output_len): - idx = (output[:, Sy - 1] == self.end_index) | ( - output[:, Sy - 1] == self.pad_index - ) - output[idx, Sy] = self.pad_index + idx = (output[:, Sy - 1] == end_index) | (output[:, Sy - 1] == pad_index) + output[idx, Sy] = pad_index return output |
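As a reading aid, and not part of the commit itself: a minimal usage sketch of the renamed Tokenizer API that the data modules and LitTransformer.get_text now rely on. It assumes the text_recognizer package is importable and that the relocated emnist_essentials.json ships with it; the example text and the wrapping step are illustrative only.

    import torch

    from text_recognizer.data.tokenizer import Tokenizer

    # Build the tokenizer from the bundled emnist_essentials.json mapping.
    tokenizer = Tokenizer()

    # encode() maps characters to indices; cast to long so decode() can index the mapping.
    ids = tokenizer.encode("hello world").long()

    # Surround the sequence with the start/end special-token indices exposed by the tokenizer.
    wrapped = torch.cat(
        (
            torch.tensor([tokenizer.start_index]),
            ids,
            torch.tensor([tokenizer.end_index]),
        )
    )

    # decode() drops start/end/pad indices, so the original text round-trips.
    print(tokenizer.decode(wrapped))  # "hello world"

The same decode() call is what get_text() uses to turn predicted and target index tensors into strings before computing accuracy, CER, and WER.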