path: root/text_recognizer
Diffstat (limited to 'text_recognizer')
-rw-r--r--  text_recognizer/data/base_data_module.py | 11
-rw-r--r--  text_recognizer/data/emnist.py | 10
-rw-r--r--  text_recognizer/data/emnist_essentials.json (renamed from text_recognizer/data/mappings/emnist_essentials.json) | 2
-rw-r--r--  text_recognizer/data/emnist_lines.py | 16
-rw-r--r--  text_recognizer/data/iam.py | 6
-rw-r--r--  text_recognizer/data/iam_extended_paragraphs.py | 12
-rw-r--r--  text_recognizer/data/iam_lines.py | 15
-rw-r--r--  text_recognizer/data/iam_paragraphs.py | 12
-rw-r--r--  text_recognizer/data/iam_synthetic_paragraphs.py | 12
-rw-r--r--  text_recognizer/data/mappings/__init__.py | 2
-rw-r--r--  text_recognizer/data/tokenizer.py (renamed from text_recognizer/data/mappings/emnist.py) | 32
-rw-r--r--  text_recognizer/metadata/shared.py | 3
-rw-r--r--  text_recognizer/models/base.py | 9
-rw-r--r--  text_recognizer/models/transformer.py | 66
14 files changed, 108 insertions, 100 deletions
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index 7863333..bd6fd99 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -6,7 +6,7 @@ from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from text_recognizer.data.base_dataset import BaseDataset
-from text_recognizer.data.mappings import EmnistMapping
+from text_recognizer.data.tokenizer import Tokenizer
T = TypeVar("T")
@@ -24,7 +24,7 @@ class BaseDataModule(LightningDataModule):
def __init__(
self,
- mapping: EmnistMapping,
+ tokenizer: Tokenizer,
transform: Optional[Callable] = None,
test_transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
@@ -34,7 +34,7 @@ class BaseDataModule(LightningDataModule):
pin_memory: bool = True,
) -> None:
super().__init__()
- self.mapping = mapping
+ self.tokenizer = tokenizer
self.transform = transform
self.test_transform = test_transform
self.target_transform = target_transform
@@ -50,11 +50,6 @@ class BaseDataModule(LightningDataModule):
self.dims: Tuple[int, ...] = None
self.output_dims: Tuple[int, ...] = None
- @classmethod
- def data_dirname(cls: T) -> Path:
- """Return the path to the base data directory."""
- return Path(__file__).resolve().parents[2] / "data"
-
def config(self) -> Dict:
"""Return important settings of the dataset."""
return {
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index 143705e..b5db075 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -13,7 +13,6 @@ from loguru import logger as log
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.base_dataset import BaseDataset, split_dataset
-from text_recognizer.data.transforms.load_transform import load_transform_from_file
from text_recognizer.data.utils.download_utils import download_dataset
import text_recognizer.metadata.emnist as metadata
@@ -32,7 +31,7 @@ class EMNIST(BaseDataModule):
def __init__(self) -> None:
super().__init__()
- self.dims = (1, *self.mapping.input_size)
+ self.dims = (1, *self.tokenizer.input_size)
def prepare_data(self) -> None:
"""Downloads dataset if not present."""
@@ -65,8 +64,8 @@ class EMNIST(BaseDataModule):
"""Returns string with info about the dataset."""
basic = (
"EMNIST Dataset\n"
- f"Num classes: {len(self.mapping)}\n"
- f"Mapping: {self.mapping}\n"
+ f"Num classes: {len(self.tokenizer)}\n"
+ f"Mapping: {self.tokenizer}\n"
f"Dims: {self.dims}\n"
)
if not any([self.data_train, self.data_val, self.data_test]):
@@ -193,5 +192,4 @@ def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]:
def download_emnist() -> None:
"""Download dataset from internet, if it does not exists, and displays info."""
- transform = load_transform_from_file("transform/default.yaml")
- load_and_print_info(EMNIST(transform=transform, test_transfrom=transform))
+ load_and_print_info(EMNIST())
diff --git a/text_recognizer/data/mappings/emnist_essentials.json b/text_recognizer/data/emnist_essentials.json
index c412425..956c28d 100644
--- a/text_recognizer/data/mappings/emnist_essentials.json
+++ b/text_recognizer/data/emnist_essentials.json
@@ -1 +1 @@
-{"characters": ["<b>", "<s>", "<e>", "<p>", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]} \ No newline at end of file
+{"characters": ["<b>", "<s>", "<e>", "<p>", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]}
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 88aac0d..8a31c44 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -12,7 +12,7 @@ from torch import Tensor
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels
from text_recognizer.data.emnist import EMNIST
-from text_recognizer.data.mappings import EmnistMapping
+from text_recognizer.data.tokenizer import Tokenizer
from text_recognizer.data.stems.line import LineStem
from text_recognizer.data.utils.sentence_generator import SentenceGenerator
import text_recognizer.metadata.emnist_lines as metadata
@@ -23,7 +23,7 @@ class EMNISTLines(BaseDataModule):
def __init__(
self,
- mapping: EmnistMapping,
+ tokenizer: Tokenizer,
transform: Optional[Callable] = None,
test_transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
@@ -39,7 +39,7 @@ class EMNISTLines(BaseDataModule):
num_test: int = 2_000,
) -> None:
super().__init__(
- mapping,
+ tokenizer,
transform,
test_transform,
target_transform,
@@ -120,7 +120,7 @@ class EMNISTLines(BaseDataModule):
"EMNISTLines2 Dataset\n" # pylint: disable=no-member
f"Min overlap: {self.min_overlap}\n"
f"Max overlap: {self.max_overlap}\n"
- f"Num classes: {len(self.mapping)}\n"
+ f"Num classes: {len(self.tokenizer)}\n"
f"Dims: {self.dims}\n"
f"Output dims: {self.output_dims}\n"
)
@@ -153,17 +153,17 @@ class EMNISTLines(BaseDataModule):
if split == "train":
samples_by_char = _get_samples_by_char(
- emnist.x_train, emnist.y_train, self.mapping.mapping
+ emnist.x_train, emnist.y_train, self.tokenizer.mapping
)
num = self.num_train
elif split == "val":
samples_by_char = _get_samples_by_char(
- emnist.x_train, emnist.y_train, self.mapping.mapping
+ emnist.x_train, emnist.y_train, self.tokenizer.mapping
)
num = self.num_val
else:
samples_by_char = _get_samples_by_char(
- emnist.x_test, emnist.y_test, self.mapping.mapping
+ emnist.x_test, emnist.y_test, self.tokenizer.mapping
)
num = self.num_test
@@ -178,7 +178,7 @@ class EMNISTLines(BaseDataModule):
self.dims,
)
y = convert_strings_to_labels(
- y, self.mapping.inverse_mapping, length=metadata.MAX_OUTPUT_LENGTH
+ y, self.tokenizer.inverse_mapping, length=metadata.MAX_OUTPUT_LENGTH
)
f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf")
f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf")
diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py
index 2ce1e9c..8a31205 100644
--- a/text_recognizer/data/iam.py
+++ b/text_recognizer/data/iam.py
@@ -4,14 +4,14 @@ Which encompasses both paragraphs and lines, with associated utilities.
"""
import os
-import xml.etree.ElementTree as ElementTree
-import zipfile
from pathlib import Path
from typing import Any, Dict, List
+import xml.etree.ElementTree as ElementTree
+import zipfile
-import toml
from boltons.cacheutils import cachedproperty
from loguru import logger as log
+import toml
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.utils.download_utils import download_dataset
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 658626c..c6628a8 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -7,7 +7,7 @@ from text_recognizer.data.base_data_module import BaseDataModule, load_and_print
from text_recognizer.data.iam_paragraphs import IAMParagraphs
from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs
from text_recognizer.data.transforms.pad import Pad
-from text_recognizer.data.mappings import EmnistMapping
+from text_recognizer.data.tokenizer import Tokenizer
from text_recognizer.data.stems.paragraph import ParagraphStem
import text_recognizer.metadata.iam_paragraphs as metadata
@@ -17,7 +17,7 @@ class IAMExtendedParagraphs(BaseDataModule):
def __init__(
self,
- mapping: EmnistMapping,
+ tokenizer: Tokenizer,
transform: Optional[Callable] = None,
test_transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
@@ -27,7 +27,7 @@ class IAMExtendedParagraphs(BaseDataModule):
pin_memory: bool = True,
) -> None:
super().__init__(
- mapping,
+ tokenizer,
transform,
test_transform,
target_transform,
@@ -37,7 +37,7 @@ class IAMExtendedParagraphs(BaseDataModule):
pin_memory,
)
self.iam_paragraphs = IAMParagraphs(
- mapping=self.mapping,
+ tokenizer=self.tokenizer,
batch_size=self.batch_size,
num_workers=self.num_workers,
train_fraction=self.train_fraction,
@@ -46,7 +46,7 @@ class IAMExtendedParagraphs(BaseDataModule):
target_transform=self.target_transform,
)
self.iam_synthetic_paragraphs = IAMSyntheticParagraphs(
- mapping=self.mapping,
+ tokenizer=self.tokenizer,
batch_size=self.batch_size,
num_workers=self.num_workers,
train_fraction=self.train_fraction,
@@ -78,7 +78,7 @@ class IAMExtendedParagraphs(BaseDataModule):
"""Returns info about the dataset."""
basic = (
"IAM Original and Synthetic Paragraphs Dataset\n" # pylint: disable=no-member
- f"Num classes: {len(self.mapping)}\n"
+ f"Num classes: {len(self.tokenizer)}\n"
f"Dims: {self.dims}\n"
f"Output dims: {self.output_dims}\n"
)
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index e60d1ba..a0d9b59 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -19,8 +19,7 @@ from text_recognizer.data.base_dataset import (
split_dataset,
)
from text_recognizer.data.iam import IAM
-from text_recognizer.data.mappings import EmnistMapping
-from text_recognizer.data.transforms.load_transform import load_transform_from_file
+from text_recognizer.data.tokenizer import Tokenizer
from text_recognizer.data.stems.line import IamLinesStem
from text_recognizer.data.utils import image_utils
import text_recognizer.metadata.iam_lines as metadata
@@ -33,7 +32,7 @@ class IAMLines(BaseDataModule):
def __init__(
self,
- mapping: EmnistMapping,
+ tokenizer: Tokenizer,
transform: Optional[Callable] = None,
test_transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
@@ -43,7 +42,7 @@ class IAMLines(BaseDataModule):
pin_memory: bool = True,
) -> None:
super().__init__(
- mapping,
+ tokenizer,
transform,
test_transform,
target_transform,
@@ -61,7 +60,7 @@ class IAMLines(BaseDataModule):
return
log.info("Cropping IAM lines regions...")
- iam = IAM(mapping=EmnistMapping())
+ iam = IAM(tokenizer=self.tokenizer)
iam.prepare_data()
crops_train, labels_train = line_crops_and_labels(iam, "train")
crops_test, labels_test = line_crops_and_labels(iam, "test")
@@ -100,7 +99,7 @@ class IAMLines(BaseDataModule):
raise ValueError("Target length longer than max output length.")
y_train = convert_strings_to_labels(
- labels_train, self.mapping.inverse_mapping, length=self.output_dims[0]
+ labels_train, self.tokenizer.inverse_mapping, length=self.output_dims[0]
)
data_train = BaseDataset(
x_train,
@@ -122,7 +121,7 @@ class IAMLines(BaseDataModule):
raise ValueError("Taget length longer than max output length.")
y_test = convert_strings_to_labels(
- labels_test, self.mapping.inverse_mapping, length=self.output_dims[0]
+ labels_test, self.tokenizer.inverse_mapping, length=self.output_dims[0]
)
self.data_test = BaseDataset(
x_test,
@@ -144,7 +143,7 @@ class IAMLines(BaseDataModule):
"""Return information about the dataset."""
basic = (
"IAM Lines dataset\n"
- f"Num classes: {len(self.mapping)}\n"
+ f"Num classes: {len(self.tokenizer)}\n"
f"Input dims: {self.dims}\n"
f"Output dims: {self.output_dims}\n"
)
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index fe1f15c..a078c7d 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -17,7 +17,7 @@ from text_recognizer.data.base_dataset import (
)
from text_recognizer.data.iam import IAM
from text_recognizer.data.transforms.pad import Pad
-from text_recognizer.data.mappings import EmnistMapping
+from text_recognizer.data.tokenizer import Tokenizer
from text_recognizer.data.stems.paragraph import ParagraphStem
import text_recognizer.metadata.iam_paragraphs as metadata
@@ -27,7 +27,7 @@ class IAMParagraphs(BaseDataModule):
def __init__(
self,
- mapping: EmnistMapping,
+ tokenizer: Tokenizer,
transform: Optional[Callable] = None,
test_transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
@@ -37,7 +37,7 @@ class IAMParagraphs(BaseDataModule):
pin_memory: bool = True,
) -> None:
super().__init__(
- mapping,
+ tokenizer,
transform,
test_transform,
target_transform,
@@ -56,7 +56,7 @@ class IAMParagraphs(BaseDataModule):
log.info("Cropping IAM paragraph regions and saving them along with labels...")
- iam = IAM(mapping=EmnistMapping(extra_symbols={metadata.NEW_LINE_TOKEN}))
+ iam = IAM(tokenizer=self.tokenizer)
iam.prepare_data()
properties = {}
@@ -88,7 +88,7 @@ class IAMParagraphs(BaseDataModule):
data = [resize_image(crop, metadata.IMAGE_SCALE_FACTOR) for crop in crops]
targets = convert_strings_to_labels(
strings=labels,
- mapping=self.mapping.inverse_mapping,
+ mapping=self.tokenizer.inverse_mapping,
length=self.output_dims[0],
)
return BaseDataset(
@@ -122,7 +122,7 @@ class IAMParagraphs(BaseDataModule):
"""Return information about the dataset."""
basic = (
"IAM Paragraphs Dataset\n"
- f"Num classes: {len(self.mapping)}\n"
+ f"Num classes: {len(self.tokenizer)}\n"
f"Input dims: {self.dims}\n"
f"Output dims: {self.output_dims}\n"
)
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index 91fda4a..511a8d4 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -19,7 +19,7 @@ from text_recognizer.data.iam_lines import (
load_line_crops_and_labels,
save_images_and_labels,
)
-from text_recognizer.data.mappings import EmnistMapping
+from text_recognizer.data.tokenizer import Tokenizer
from text_recognizer.data.stems.paragraph import ParagraphStem
from text_recognizer.data.transforms.pad import Pad
import text_recognizer.metadata.iam_synthetic_paragraphs as metadata
@@ -30,7 +30,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
def __init__(
self,
- mapping: EmnistMapping,
+ tokenizer: Tokenizer,
transform: Optional[Callable] = None,
test_transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
@@ -40,7 +40,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
pin_memory: bool = True,
) -> None:
super().__init__(
- mapping,
+ tokenizer,
transform,
test_transform,
target_transform,
@@ -58,7 +58,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
log.info("Preparing IAM lines for synthetic paragraphs dataset.")
log.info("Cropping IAM line regions and loading labels.")
- iam = IAM(mapping=EmnistMapping(extra_symbols=(metadata.NEW_LINE_TOKEN,)))
+ iam = IAM(tokenizer=self.tokenizer)
iam.prepare_data()
crops_train, labels_train = line_crops_and_labels(iam, "train")
@@ -94,7 +94,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
targets = convert_strings_to_labels(
strings=paragraphs_labels,
- mapping=self.mapping.inverse_mapping,
+ mapping=self.tokenizer.inverse_mapping,
length=self.output_dims[0],
)
self.data_train = BaseDataset(
@@ -108,7 +108,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
"""Return information about the dataset."""
basic = (
"IAM Synthetic Paragraphs Dataset\n" # pylint: disable=no-member
- f"Num classes: {len(self.mapping)}\n"
+ f"Num classes: {len(self.tokenizer)}\n"
f"Input dims : {self.dims}\n"
f"Output dims: {self.output_dims}\n"
)
diff --git a/text_recognizer/data/mappings/__init__.py b/text_recognizer/data/mappings/__init__.py
deleted file mode 100644
index 635f506..0000000
--- a/text_recognizer/data/mappings/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-"""Mapping modules."""
-from text_recognizer.data.mappings.emnist import EmnistMapping
diff --git a/text_recognizer/data/mappings/emnist.py b/text_recognizer/data/tokenizer.py
index 331976e..a5f44e6 100644
--- a/text_recognizer/data/mappings/emnist.py
+++ b/text_recognizer/data/tokenizer.py
@@ -6,19 +6,29 @@ from typing import Dict, List, Optional, Sequence, Tuple, Union
import torch
from torch import Tensor
-ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json"
+import text_recognizer.metadata.shared as metadata
-class EmnistMapping:
+class Tokenizer:
"""Mapping for EMNIST labels."""
def __init__(
self,
extra_symbols: Optional[Sequence[str]] = None,
lower: bool = True,
+ start_token: str = "<s>",
+ end_token: str = "<e>",
+ pad_token: str = "<p>",
) -> None:
self.extra_symbols = set(extra_symbols) if extra_symbols is not None else None
self.mapping, self.inverse_mapping, self.input_size = self._load_mapping()
+ self.start_token = start_token
+ self.end_token = end_token
+ self.pad_token = pad_token
+ self.start_index = int(self.get_value(self.start_token))
+ self.end_index = int(self.get_value(self.end_token))
+ self.pad_index = int(self.get_value(self.pad_token))
+ self.ignore_indices = set([self.start_index, self.end_index, self.pad_index])
if lower:
self._to_lower()
@@ -31,7 +41,7 @@ class EmnistMapping:
def _load_mapping(self) -> Tuple[List, Dict[str, int], List[int]]:
"""Return the EMNIST mapping."""
- with ESSENTIALS_FILENAME.open() as f:
+ with metadata.ESSENTIALS_FILENAME.open() as f:
essentials = json.load(f)
mapping = list(essentials["characters"])
if self.extra_symbols is not None:
@@ -57,19 +67,25 @@ class EmnistMapping:
return self.mapping[index]
raise KeyError(f"Index ({index}) not in mapping.")
- def get_index(self, token: str) -> Tensor:
+ def get_value(self, token: str) -> Tensor:
"""Returns index value of token."""
if token in self.inverse_mapping:
return torch.LongTensor([self.inverse_mapping[token]])
raise KeyError(f"Token ({token}) not found in inverse mapping.")
- def get_text(self, indices: Union[List[int], Tensor]) -> str:
+ def decode(self, indices: Union[List[int], Tensor]) -> str:
"""Returns the text from a list of indices."""
if isinstance(indices, Tensor):
indices = indices.tolist()
- return "".join([self.mapping[index] for index in indices])
-
- def get_indices(self, text: str) -> Tensor:
+ return "".join(
+ [
+ self.mapping[index]
+ for index in indices
+ if index not in self.ignore_indices
+ ]
+ )
+
+ def encode(self, text: str) -> Tensor:
"""Returns tensor of indices for a string."""
return Tensor([self.inverse_mapping[token] for token in text])
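The renamed Tokenizer keeps the EMNIST mapping but folds in the special-token bookkeeping that used to live on the Lightning model: encode maps characters to indices through inverse_mapping, and decode drops the start/end/pad indices before joining. A minimal, self-contained sketch of that round trip, with a shortened character list inlined in place of emnist_essentials.json (the trimmed alphabet is purely for illustration):

import torch
from torch import Tensor

# Illustrative stand-in for text_recognizer.data.tokenizer.Tokenizer; the real class
# loads its characters from emnist_essentials.json rather than inlining them.
mapping = ["<b>", "<s>", "<e>", "<p>", "a", "b", "c", " "]
inverse_mapping = {token: index for index, token in enumerate(mapping)}
ignore_indices = {inverse_mapping[t] for t in ("<s>", "<e>", "<p>")}

def encode(text: str) -> Tensor:
    # Mirrors Tokenizer.encode: one index per character.
    return Tensor([inverse_mapping[token] for token in text])

def decode(indices) -> str:
    # Mirrors Tokenizer.decode: skip start/end/pad indices, join the rest.
    if isinstance(indices, Tensor):
        indices = indices.tolist()
    return "".join(mapping[int(i)] for i in indices if int(i) not in ignore_indices)

print(encode("ab c"))                 # tensor([4., 5., 7., 6.])
print(decode([1, 4, 5, 7, 6, 2, 3]))  # "ab c" -- <s>, <e>, <p> are dropped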
diff --git a/text_recognizer/metadata/shared.py b/text_recognizer/metadata/shared.py
index a4d1da0..cee5de4 100644
--- a/text_recognizer/metadata/shared.py
+++ b/text_recognizer/metadata/shared.py
@@ -1,4 +1,7 @@
from pathlib import Path
+ESSENTIALS_FILENAME = (
+ Path(__file__).parents[1].resolve() / "data" / "emnist_essentials.json"
+)
DATA_DIRNAME = Path(__file__).resolve().parents[2] / "data"
DOWNLOADED_DATA_DIRNAME = DATA_DIRNAME / "downloded"
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index bb4e695..f8f4b40 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -9,7 +9,7 @@ from pytorch_lightning import LightningModule
from torch import nn, Tensor
from torchmetrics import Accuracy
-from text_recognizer.data.mappings import EmnistMapping
+from text_recognizer.data.tokenizer import Tokenizer
class LitBase(LightningModule):
@@ -21,8 +21,7 @@ class LitBase(LightningModule):
loss_fn: Type[nn.Module],
optimizer_config: DictConfig,
lr_scheduler_config: Optional[DictConfig],
- mapping: EmnistMapping,
- ignore_index: Optional[int] = None,
+ tokenizer: Tokenizer,
) -> None:
super().__init__()
@@ -30,8 +29,8 @@ class LitBase(LightningModule):
self.loss_fn = loss_fn
self.optimizer_config = optimizer_config
self.lr_scheduler_config = lr_scheduler_config
- self.mapping = mapping
-
+ self.tokenizer = tokenizer
+ ignore_index = int(self.tokenizer.get_value("<p>"))
# Placeholders
self.train_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index)
self.val_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index)
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 2c74b7e..752f3eb 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,11 +1,12 @@
"""Lightning model for base Transformers."""
+from collections.abc import Sequence
from typing import Optional, Tuple, Type
import torch
from omegaconf import DictConfig
from torch import nn, Tensor
-from text_recognizer.data.mappings import EmnistMapping
+from text_recognizer.data.tokenizer import Tokenizer
from text_recognizer.models.base import LitBase
from text_recognizer.models.metrics.cer import CharacterErrorRate
from text_recognizer.models.metrics.wer import WordErrorRate
@@ -19,33 +20,23 @@ class LitTransformer(LitBase):
network: Type[nn.Module],
loss_fn: Type[nn.Module],
optimizer_config: DictConfig,
- mapping: EmnistMapping,
+ tokenizer: Tokenizer,
lr_scheduler_config: Optional[DictConfig] = None,
max_output_len: int = 682,
- start_token: str = "<s>",
- end_token: str = "<e>",
- pad_token: str = "<p>",
) -> None:
- self.max_output_len = max_output_len
- self.start_token = start_token
- self.end_token = end_token
- self.pad_token = pad_token
- self.start_index = int(self.mapping.get_index(self.start_token))
- self.end_index = int(self.mapping.get_index(self.end_token))
- self.pad_index = int(self.mapping.get_index(self.pad_token))
- self.ignore_indices = set([self.start_index, self.end_index, self.pad_index])
- self.val_cer = CharacterErrorRate(self.ignore_indices)
- self.test_cer = CharacterErrorRate(self.ignore_indices)
- self.val_wer = WordErrorRate(self.ignore_indices)
- self.test_wer = WordErrorRate(self.ignore_indices)
super().__init__(
network,
loss_fn,
optimizer_config,
lr_scheduler_config,
- mapping,
- self.pad_index,
+ tokenizer,
)
+ self.max_output_len = max_output_len
+ self.ignore_indices = self.tokenizer.ignore_indices
+ self.val_cer = CharacterErrorRate(self.ignore_indices)
+ self.test_cer = CharacterErrorRate(self.ignore_indices)
+ self.val_wer = WordErrorRate(self.ignore_indices)
+ self.test_wer = WordErrorRate(self.ignore_indices)
def forward(self, data: Tensor) -> Tensor:
"""Forward pass with the transformer network."""
@@ -63,11 +54,12 @@ class LitTransformer(LitBase):
"""Validation step."""
data, targets = batch
preds = self.predict(data)
- self.val_acc(preds, targets)
+ pred_text, target_text = self.get_text(preds, targets)
+ self.val_acc(pred_text, target_text)
self.log("val/acc", self.val_acc, on_step=False, on_epoch=True)
- self.val_cer(preds, targets)
+ self.val_cer(pred_text, target_text)
self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
- self.val_wer(preds, targets)
+ self.val_wer(pred_text, target_text)
self.log("val/wer", self.val_wer, on_step=False, on_epoch=True, prog_bar=True)
def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
@@ -75,14 +67,22 @@ class LitTransformer(LitBase):
data, targets = batch
# Compute the text prediction.
- pred = self(data)
- self.test_acc(pred, targets)
+ preds = self(data)
+ pred_text, target_text = self.get_text(preds, targets)
+ self.test_acc(pred_text, target_text)
self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)
- self.test_cer(pred, targets)
+ self.test_cer(pred_text, target_text)
self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
- self.test_wer(pred, targets)
+ self.test_wer(pred_text, target_text)
self.log("test/wer", self.test_wer, on_step=False, on_epoch=True, prog_bar=True)
+ def get_text(
+ self, preds: Tensor, targets: Tensor
+ ) -> Tuple[Sequence[str], Sequence[str]]:
+ pred_text = [self.tokenizer.decode(p) for p in preds]
+ target_text = [self.tokenizer.decode(t) for t in targets]
+ return pred_text, target_text
+
@torch.no_grad()
def predict(self, x: Tensor) -> Tensor:
"""Predicts text in image.
@@ -97,6 +97,9 @@ class LitTransformer(LitBase):
Returns:
Tensor: A tensor of token indices of the predictions from the model.
"""
+ start_index = self.tokenizer.start_index
+ end_index = self.tokenizer.end_index
+ pad_index = self.tokenizer.pad_index
bsz = x.shape[0]
# Encode image(s) to latent vectors.
@@ -104,7 +107,7 @@ class LitTransformer(LitBase):
# Create a placeholder matrix for storing outputs from the network
output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
- output[:, 0] = self.start_index
+ output[:, 0] = start_index
for Sy in range(1, self.max_output_len):
context = output[:, :Sy] # (B, Sy)
@@ -114,16 +117,13 @@ class LitTransformer(LitBase):
# Early stopping of prediction loop if token is end or padding token.
if (
- (output[:, Sy - 1] == self.end_index)
- | (output[:, Sy - 1] == self.pad_index)
+ (output[:, Sy - 1] == end_index) | (output[:, Sy - 1] == pad_index)
).all():
break
# Set all tokens after end token to pad token.
for Sy in range(1, self.max_output_len):
- idx = (output[:, Sy - 1] == self.end_index) | (
- output[:, Sy - 1] == self.pad_index
- )
- output[idx, Sy] = self.pad_index
+ idx = (output[:, Sy - 1] == end_index) | (output[:, Sy - 1] == pad_index)
+ output[idx, Sy] = pad_index
return output
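With the special-token indices now read from the tokenizer, the greedy decoding loop in predict can be summarized as below. This is a hedged, self-contained sketch of the same pattern with a stand-in network interface; the encode/decode method names and the logits layout are assumptions, and the elided middle of the loop above is reconstructed from the surrounding lines rather than copied from the repo:

import torch

@torch.no_grad()
def greedy_decode_sketch(network, x, start_index, end_index, pad_index, max_output_len):
    # Seed each row with <s>, append the argmax token one position at a time,
    # stop once every sequence has emitted <e> or <p>, then overwrite everything
    # after the end token with the pad index -- the same flow as predict() above.
    bsz = x.shape[0]
    z = network.encode(x)                                    # assumed latent encoding
    output = torch.ones((bsz, max_output_len), dtype=torch.long, device=x.device)
    output[:, 0] = start_index
    for Sy in range(1, max_output_len):
        context = output[:, :Sy]                             # (B, Sy)
        logits = network.decode(z, context)                  # assumed (B, C, Sy)
        output[:, Sy] = logits.argmax(dim=1)[:, -1]
        if ((output[:, Sy] == end_index) | (output[:, Sy] == pad_index)).all():
            break
    for Sy in range(1, max_output_len):
        idx = (output[:, Sy - 1] == end_index) | (output[:, Sy - 1] == pad_index)
        output[idx, Sy] = pad_index
    return output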