path: root/text_recognizer
Diffstat (limited to 'text_recognizer')
-rw-r--r--  text_recognizer/criterions/label_smoothing.py | 38
-rw-r--r--  text_recognizer/data/base_data_module.py | 6
-rw-r--r--  text_recognizer/data/base_mapping.py | 37
-rw-r--r--  text_recognizer/data/download_utils.py | 2
-rw-r--r--  text_recognizer/data/emnist_mapping.py | 37
-rw-r--r--  text_recognizer/data/iam_extended_paragraphs.py | 3
-rw-r--r--  text_recognizer/data/iam_lines.py | 2
-rw-r--r--  text_recognizer/data/iam_paragraphs.py | 12
-rw-r--r--  text_recognizer/data/iam_synthetic_paragraphs.py | 4
-rw-r--r--  text_recognizer/data/make_wordpieces.py | 2
-rw-r--r--  text_recognizer/data/mappings.py | 156
-rw-r--r--  text_recognizer/data/transforms.py | 8
-rw-r--r--  text_recognizer/data/word_piece_mapping.py | 93
-rw-r--r--  text_recognizer/models/base.py | 20
-rw-r--r--  text_recognizer/models/transformer.py | 36
-rw-r--r--  text_recognizer/networks/conv_transformer.py | 42
-rw-r--r--  text_recognizer/networks/encoders/efficientnet/mbconv.py | 9
-rw-r--r--  text_recognizer/networks/transformer/layers.py | 27
18 files changed, 259 insertions(+), 275 deletions(-)
diff --git a/text_recognizer/criterions/label_smoothing.py b/text_recognizer/criterions/label_smoothing.py
index 40a7609..cc71c45 100644
--- a/text_recognizer/criterions/label_smoothing.py
+++ b/text_recognizer/criterions/label_smoothing.py
@@ -6,37 +6,31 @@ import torch.nn.functional as F
class LabelSmoothingLoss(nn.Module):
- """Label smoothing cross entropy loss."""
-
- def __init__(
- self, label_smoothing: float, vocab_size: int, ignore_index: int = -100
- ) -> None:
- assert 0.0 < label_smoothing <= 1.0
- self.ignore_index = ignore_index
+ def __init__(self, ignore_index: int = -100, smoothing: float = 0.0, dim: int = -1):
super().__init__()
+        assert 0.0 <= smoothing <= 1.0
+ self.ignore_index = ignore_index
+ self.confidence = 1.0 - smoothing
+ self.smoothing = smoothing
+ self.dim = dim
- smoothing_value = label_smoothing / (vocab_size - 2)
- one_hot = torch.full((vocab_size,), smoothing_value)
- one_hot[self.ignore_index] = 0
- self.register_buffer("one_hot", one_hot.unsqueeze(0))
-
- self.confidence = 1.0 - label_smoothing
-
- def forward(self, output: Tensor, targets: Tensor) -> Tensor:
+ def forward(self, output: Tensor, target: Tensor) -> Tensor:
"""Computes the loss.
Args:
- output (Tensor): Predictions from the network.
+            output (Tensor): Predictions from the network.
targets (Tensor): Ground truth.
Shapes:
- outpus: Batch size x num classes
- targets: Batch size
+            output: [batch size, num classes], target: [batch size]
Returns:
Tensor: Label smoothing loss.
"""
- model_prob = self.one_hot.repeat(targets.size(0), 1)
- model_prob.scatter_(1, targets.unsqueeze(1), self.confidence)
- model_prob.masked_fill_((targets == self.ignore_index).unsqueeze(1), 0)
- return F.kl_div(output, model_prob, reduction="sum")
+ output = output.log_softmax(dim=self.dim)
+ with torch.no_grad():
+ true_dist = torch.zeros_like(output)
+ true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
+            true_dist.masked_fill_((target == self.ignore_index).unsqueeze(1), 0)
+ true_dist += self.smoothing / output.size(self.dim)
+ return torch.mean(torch.sum(-true_dist * output, dim=self.dim))
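
As a quick sanity check, a minimal sketch of how the rewritten loss would be exercised (assumes the repository root is on PYTHONPATH; shapes and values are illustrative):

    import torch
    from text_recognizer.criterions.label_smoothing import LabelSmoothingLoss

    loss_fn = LabelSmoothingLoss(ignore_index=-100, smoothing=0.1)
    logits = torch.randn(8, 58)            # [batch size, num classes], raw scores
    targets = torch.randint(0, 58, (8,))   # [batch size]
    loss = loss_fn(logits, targets)        # log_softmax is applied inside forward()
    print(loss.item())                     # scalar label-smoothing loss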
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index fd914b6..16a06d9 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -1,12 +1,12 @@
"""Base lightning DataModule class."""
from pathlib import Path
-from typing import Dict, Tuple
+from typing import Dict, Tuple, Type
import attr
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
-from text_recognizer.data.mappings import AbstractMapping
+from text_recognizer.data.base_mapping import AbstractMapping
from text_recognizer.data.base_dataset import BaseDataset
@@ -25,7 +25,7 @@ class BaseDataModule(LightningDataModule):
def __attrs_pre_init__(self) -> None:
super().__init__()
- mapping: AbstractMapping = attr.ib()
+ mapping: Type[AbstractMapping] = attr.ib()
batch_size: int = attr.ib(default=16)
num_workers: int = attr.ib(default=0)
pin_memory: bool = attr.ib(default=True)
diff --git a/text_recognizer/data/base_mapping.py b/text_recognizer/data/base_mapping.py
new file mode 100644
index 0000000..572ac95
--- /dev/null
+++ b/text_recognizer/data/base_mapping.py
@@ -0,0 +1,37 @@
+"""Mapping to and from word pieces."""
+from abc import ABC, abstractmethod
+from typing import Dict, List
+
+from torch import Tensor
+
+
+class AbstractMapping(ABC):
+ def __init__(
+ self, input_size: List[int], mapping: List[str], inverse_mapping: Dict[str, int]
+ ) -> None:
+ self.input_size = input_size
+ self.mapping = mapping
+ self.inverse_mapping = inverse_mapping
+
+ def __len__(self) -> int:
+ return len(self.mapping)
+
+ @property
+ def num_classes(self) -> int:
+ return self.__len__()
+
+ @abstractmethod
+ def get_token(self, *args, **kwargs) -> str:
+ ...
+
+ @abstractmethod
+ def get_index(self, *args, **kwargs) -> Tensor:
+ ...
+
+ @abstractmethod
+ def get_text(self, *args, **kwargs) -> str:
+ ...
+
+ @abstractmethod
+ def get_indices(self, *args, **kwargs) -> Tensor:
+ ...
diff --git a/text_recognizer/data/download_utils.py b/text_recognizer/data/download_utils.py
index 8938830..a5a5360 100644
--- a/text_recognizer/data/download_utils.py
+++ b/text_recognizer/data/download_utils.py
@@ -1,7 +1,7 @@
"""Util functions for downloading datasets."""
import hashlib
from pathlib import Path
-from typing import Dict, List, Optional
+from typing import Dict, Optional
from urllib.request import urlretrieve
from loguru import logger as log
diff --git a/text_recognizer/data/emnist_mapping.py b/text_recognizer/data/emnist_mapping.py
new file mode 100644
index 0000000..6c4c43b
--- /dev/null
+++ b/text_recognizer/data/emnist_mapping.py
@@ -0,0 +1,37 @@
+"""Emnist mapping."""
+from typing import List, Optional, Union, Set
+
+from torch import Tensor
+
+from text_recognizer.data.base_mapping import AbstractMapping
+from text_recognizer.data.emnist import emnist_mapping
+
+
+class EmnistMapping(AbstractMapping):
+ def __init__(self, extra_symbols: Optional[Set[str]] = None) -> None:
+ self.extra_symbols = set(extra_symbols) if extra_symbols is not None else None
+ self.mapping, self.inverse_mapping, self.input_size = emnist_mapping(
+ self.extra_symbols
+ )
+ super().__init__(self.input_size, self.mapping, self.inverse_mapping)
+
+ def __attrs_post_init__(self) -> None:
+ """Post init configuration."""
+
+ def get_token(self, index: Union[int, Tensor]) -> str:
+        if (index := int(index)) < len(self.mapping):
+ return self.mapping[index]
+ raise KeyError(f"Index ({index}) not in mapping.")
+
+ def get_index(self, token: str) -> Tensor:
+ if token in self.inverse_mapping:
+            return Tensor([self.inverse_mapping[token]])
+ raise KeyError(f"Token ({token}) not found in inverse mapping.")
+
+ def get_text(self, indices: Union[List[int], Tensor]) -> str:
+ if isinstance(indices, Tensor):
+ indices = indices.tolist()
+ return "".join([self.mapping[index] for index in indices])
+
+ def get_indices(self, text: str) -> Tensor:
+ return Tensor([self.inverse_mapping[token] for token in text])
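
For reference, a minimal usage sketch of the new mapping (assumes the EMNIST essentials bundled with the repository are available, since emnist_mapping() loads them; sample indices and text are illustrative):

    from text_recognizer.data.emnist_mapping import EmnistMapping

    mapping = EmnistMapping(extra_symbols={"\n"})
    print(mapping.num_classes)              # total number of classes, incl. extra symbols
    print(mapping.get_text([36, 40, 44]))   # indices -> string
    print(mapping.get_indices("abc"))       # string -> Tensor of indices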
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index ccf0759..df0c0e1 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -1,6 +1,4 @@
"""IAM original and sythetic dataset class."""
-from typing import Dict, List
-
import attr
from torch.utils.data import ConcatDataset
@@ -15,7 +13,6 @@ class IAMExtendedParagraphs(BaseDataModule):
augment: bool = attr.ib(default=True)
train_fraction: float = attr.ib(default=0.8)
word_pieces: bool = attr.ib(default=False)
- num_classes: int = attr.ib(init=False)
def __attrs_post_init__(self) -> None:
self.iam_paragraphs = IAMParagraphs(
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index 1c63729..aba38f9 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -22,7 +22,7 @@ from text_recognizer.data.base_dataset import (
split_dataset,
)
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
-from text_recognizer.data.mappings import EmnistMapping
+from text_recognizer.data.emnist_mapping import EmnistMapping
from text_recognizer.data.iam import IAM
from text_recognizer.data import image_utils
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 6189f7d..11f899f 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -17,7 +17,7 @@ from text_recognizer.data.base_dataset import (
split_dataset,
)
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
-from text_recognizer.data.mappings import EmnistMapping
+from text_recognizer.data.emnist_mapping import EmnistMapping
from text_recognizer.data.iam import IAM
from text_recognizer.data.transforms import WordPiece
@@ -50,11 +50,9 @@ class IAMParagraphs(BaseDataModule):
if PROCESSED_DATA_DIRNAME.exists():
return
- log.info(
- "Cropping IAM paragraph regions and saving them along with labels..."
- )
+ log.info("Cropping IAM paragraph regions and saving them along with labels...")
- iam = IAM(mapping=EmnistMapping())
+ iam = IAM(mapping=EmnistMapping(extra_symbols={NEW_LINE_TOKEN,}))
iam.prepare_data()
properties = {}
@@ -83,7 +81,9 @@ class IAMParagraphs(BaseDataModule):
crops, labels = _load_processed_crops_and_labels(split)
data = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops]
targets = convert_strings_to_labels(
- strings=labels, mapping=self.mapping.inverse_mapping, length=self.output_dims[0]
+ strings=labels,
+ mapping=self.mapping.inverse_mapping,
+ length=self.output_dims[0],
)
return BaseDataset(
data,
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index c938f8b..24ca896 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -21,7 +21,7 @@ from text_recognizer.data.iam_paragraphs import (
IMAGE_SCALE_FACTOR,
resize_image,
)
-from text_recognizer.data.mappings import EmnistMapping
+from text_recognizer.data.emnist_mapping import EmnistMapping
from text_recognizer.data.iam import IAM
from text_recognizer.data.iam_lines import (
line_crops_and_labels,
@@ -47,7 +47,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
log.info("Preparing IAM lines for synthetic paragraphs dataset.")
log.info("Cropping IAM line regions and loading labels.")
- iam = IAM(mapping=EmnistMapping())
+ iam = IAM(mapping=EmnistMapping(extra_symbols={NEW_LINE_TOKEN,}))
iam.prepare_data()
crops_train, labels_train = line_crops_and_labels(iam, "train")
diff --git a/text_recognizer/data/make_wordpieces.py b/text_recognizer/data/make_wordpieces.py
index 40fbee4..8e53815 100644
--- a/text_recognizer/data/make_wordpieces.py
+++ b/text_recognizer/data/make_wordpieces.py
@@ -13,8 +13,6 @@ import click
from loguru import logger as log
import sentencepiece as spm
-from text_recognizer.data.iam_preprocessor import load_metadata
-
def iamdb_pieces(
data_dir: Path, text_file: str, num_pieces: int, output_prefix: str
diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py
deleted file mode 100644
index d1c64dd..0000000
--- a/text_recognizer/data/mappings.py
+++ /dev/null
@@ -1,156 +0,0 @@
-"""Mapping to and from word pieces."""
-from abc import ABC, abstractmethod
-from pathlib import Path
-from typing import Dict, List, Optional, Union, Set
-
-import attr
-import torch
-from loguru import logger as log
-from torch import Tensor
-
-from text_recognizer.data.emnist import emnist_mapping
-from text_recognizer.data.iam_preprocessor import Preprocessor
-
-
-@attr.s
-class AbstractMapping(ABC):
- input_size: List[int] = attr.ib(init=False)
- mapping: List[str] = attr.ib(init=False)
- inverse_mapping: Dict[str, int] = attr.ib(init=False)
-
- def __len__(self) -> int:
- return len(self.mapping)
-
- @property
- def num_classes(self) -> int:
- return self.__len__()
-
- @abstractmethod
- def get_token(self, *args, **kwargs) -> str:
- ...
-
- @abstractmethod
- def get_index(self, *args, **kwargs) -> Tensor:
- ...
-
- @abstractmethod
- def get_text(self, *args, **kwargs) -> str:
- ...
-
- @abstractmethod
- def get_indices(self, *args, **kwargs) -> Tensor:
- ...
-
-
-@attr.s(auto_attribs=True)
-class EmnistMapping(AbstractMapping):
- extra_symbols: Optional[Set[str]] = attr.ib(default=None)
-
- def __attrs_post_init__(self) -> None:
- """Post init configuration."""
- self.extra_symbols = set(self.extra_symbols) if self.extra_symbols is not None else None
- self.mapping, self.inverse_mapping, self.input_size = emnist_mapping(
- self.extra_symbols
- )
-
- def get_token(self, index: Union[int, Tensor]) -> str:
- if (index := int(index)) in self.mapping:
- return self.mapping[index]
- raise KeyError(f"Index ({index}) not in mapping.")
-
- def get_index(self, token: str) -> Tensor:
- if token in self.inverse_mapping:
- return Tensor(self.inverse_mapping[token])
- raise KeyError(f"Token ({token}) not found in inverse mapping.")
-
- def get_text(self, indices: Union[List[int], Tensor]) -> str:
- if isinstance(indices, Tensor):
- indices = indices.tolist()
- return "".join([self.mapping[index] for index in indices])
-
- def get_indices(self, text: str) -> Tensor:
- return Tensor([self.inverse_mapping[token] for token in text])
-
-
-@attr.s(auto_attribs=True)
-class WordPieceMapping(EmnistMapping):
- data_dir: Optional[Path] = attr.ib(default=None)
- num_features: int = attr.ib(default=1000)
- tokens: str = attr.ib(default="iamdb_1kwp_tokens_1000.txt")
- lexicon: str = attr.ib(default="iamdb_1kwp_lex_1000.txt")
- use_words: bool = attr.ib(default=False)
- prepend_wordsep: bool = attr.ib(default=False)
- special_tokens: Set[str] = attr.ib(default={"<s>", "<e>", "<p>"}, converter=set)
- extra_symbols: Set[str] = attr.ib(default={"\n",}, converter=set)
- wordpiece_processor: Preprocessor = attr.ib(init=False)
-
- def __attrs_post_init__(self) -> None:
- super().__attrs_post_init__()
- self.data_dir = (
- (
- Path(__file__).resolve().parents[2]
- / "data"
- / "downloaded"
- / "iam"
- / "iamdb"
- )
- if self.data_dir is None
- else Path(self.data_dir)
- )
- log.debug(f"Using data dir: {self.data_dir}")
- if not self.data_dir.exists():
- raise RuntimeError(f"Could not locate iamdb directory at {self.data_dir}")
-
- processed_path = (
- Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines"
- )
-
- tokens_path = processed_path / self.tokens
- lexicon_path = processed_path / self.lexicon
-
- special_tokens = self.special_tokens
- if self.extra_symbols is not None:
- special_tokens = special_tokens | self.extra_symbols
-
- self.wordpiece_processor = Preprocessor(
- data_dir=self.data_dir,
- num_features=self.num_features,
- tokens_path=tokens_path,
- lexicon_path=lexicon_path,
- use_words=self.use_words,
- prepend_wordsep=self.prepend_wordsep,
- special_tokens=special_tokens,
- )
-
- def __len__(self) -> int:
- return len(self.wordpiece_processor.tokens)
-
- def get_token(self, index: Union[int, Tensor]) -> str:
- if (index := int(index)) <= self.wordpiece_processor.num_tokens:
- return self.wordpiece_processor.tokens[index]
- raise KeyError(f"Index ({index}) not in mapping.")
-
- def get_index(self, token: str) -> Tensor:
- if token in self.wordpiece_processor.tokens:
- return torch.LongTensor([self.wordpiece_processor.tokens_to_index[token]])
- raise KeyError(f"Token ({token}) not found in inverse mapping.")
-
- def get_text(self, indices: Union[List[int], Tensor]) -> str:
- if isinstance(indices, Tensor):
- indices = indices.tolist()
- return self.wordpiece_processor.to_text(indices)
-
- def get_indices(self, text: str) -> Tensor:
- return self.wordpiece_processor.to_index(text)
-
- def emnist_to_wordpiece_indices(self, x: Tensor) -> Tensor:
- text = "".join([self.mapping[i] for i in x])
- text = text.lower().replace(" ", "▁")
- return torch.LongTensor(self.wordpiece_processor.to_index(text))
-
- def __getitem__(self, x: Union[str, int, List[int], Tensor]) -> Union[str, Tensor]:
- if isinstance(x, int):
- x = [x]
- if isinstance(x, str):
- return self.get_indices(x)
- return self.get_text(x)
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
index 3b1b929..047496f 100644
--- a/text_recognizer/data/transforms.py
+++ b/text_recognizer/data/transforms.py
@@ -1,11 +1,11 @@
"""Transforms for PyTorch datasets."""
from pathlib import Path
-from typing import Optional, Union, Sequence
+from typing import Optional, Union, Set
import torch
from torch import Tensor
-from text_recognizer.data.mappings import WordPieceMapping
+from text_recognizer.data.word_piece_mapping import WordPieceMapping
class WordPiece:
@@ -19,8 +19,8 @@ class WordPiece:
data_dir: Optional[Union[str, Path]] = None,
use_words: bool = False,
prepend_wordsep: bool = False,
- special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
- extra_symbols: Optional[Sequence[str]] = ("\n",),
+ special_tokens: Set[str] = {"<s>", "<e>", "<p>"},
+ extra_symbols: Optional[Set[str]] = {"\n",},
max_len: int = 451,
) -> None:
self.mapping = WordPieceMapping(
diff --git a/text_recognizer/data/word_piece_mapping.py b/text_recognizer/data/word_piece_mapping.py
new file mode 100644
index 0000000..59488c3
--- /dev/null
+++ b/text_recognizer/data/word_piece_mapping.py
@@ -0,0 +1,93 @@
+"""Word piece mapping."""
+from pathlib import Path
+from typing import List, Optional, Union, Set
+
+import torch
+from loguru import logger as log
+from torch import Tensor
+
+from text_recognizer.data.emnist_mapping import EmnistMapping
+from text_recognizer.data.iam_preprocessor import Preprocessor
+
+
+class WordPieceMapping(EmnistMapping):
+ def __init__(
+ self,
+ data_dir: Optional[Path] = None,
+ num_features: int = 1000,
+ tokens: str = "iamdb_1kwp_tokens_1000.txt",
+ lexicon: str = "iamdb_1kwp_lex_1000.txt",
+ use_words: bool = False,
+ prepend_wordsep: bool = False,
+ special_tokens: Set[str] = {"<s>", "<e>", "<p>"},
+ extra_symbols: Set[str] = {"\n",},
+ ) -> None:
+ super().__init__(extra_symbols=extra_symbols)
+ self.data_dir = (
+ (
+ Path(__file__).resolve().parents[2]
+ / "data"
+ / "downloaded"
+ / "iam"
+ / "iamdb"
+ )
+ if data_dir is None
+ else Path(data_dir)
+ )
+ log.debug(f"Using data dir: {self.data_dir}")
+ if not self.data_dir.exists():
+ raise RuntimeError(f"Could not locate iamdb directory at {self.data_dir}")
+
+ processed_path = (
+ Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines"
+ )
+
+ tokens_path = processed_path / tokens
+ lexicon_path = processed_path / lexicon
+
+ special_tokens = set(special_tokens)
+ if self.extra_symbols is not None:
+ special_tokens = special_tokens | set(extra_symbols)
+
+ self.wordpiece_processor = Preprocessor(
+ data_dir=self.data_dir,
+ num_features=num_features,
+ tokens_path=tokens_path,
+ lexicon_path=lexicon_path,
+ use_words=use_words,
+ prepend_wordsep=prepend_wordsep,
+ special_tokens=special_tokens,
+ )
+
+ def __len__(self) -> int:
+ return len(self.wordpiece_processor.tokens)
+
+ def get_token(self, index: Union[int, Tensor]) -> str:
+ if (index := int(index)) <= self.wordpiece_processor.num_tokens:
+ return self.wordpiece_processor.tokens[index]
+ raise KeyError(f"Index ({index}) not in mapping.")
+
+ def get_index(self, token: str) -> Tensor:
+ if token in self.wordpiece_processor.tokens:
+ return torch.LongTensor([self.wordpiece_processor.tokens_to_index[token]])
+ raise KeyError(f"Token ({token}) not found in inverse mapping.")
+
+ def get_text(self, indices: Union[List[int], Tensor]) -> str:
+ if isinstance(indices, Tensor):
+ indices = indices.tolist()
+ return self.wordpiece_processor.to_text(indices).replace(" ", "▁")
+
+ def get_indices(self, text: str) -> Tensor:
+ return self.wordpiece_processor.to_index(text)
+
+ def emnist_to_wordpiece_indices(self, x: Tensor) -> Tensor:
+ text = "".join([self.mapping[i] for i in x])
+ text = text.lower().replace(" ", "▁")
+ return torch.LongTensor(self.wordpiece_processor.to_index(text))
+
+ def __getitem__(self, x: Union[str, int, List[int], Tensor]) -> Union[str, Tensor]:
+ if isinstance(x, int):
+ x = [x]
+ if isinstance(x, str):
+ return self.get_indices(x)
+ return self.get_text(x)
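
A hedged usage sketch of the word-piece mapping (assumes the IAM database and the processed iamdb_1kwp_* token/lexicon files referenced by the defaults are already downloaded; otherwise the constructor raises):

    from text_recognizer.data.word_piece_mapping import WordPieceMapping

    mapping = WordPieceMapping()                    # iamdb_1kwp_* defaults
    text = "the quick brown fox".replace(" ", "▁")  # the processor uses "▁" as word separator
    indices = mapping.get_indices(text)             # text -> word-piece indices
    print(mapping.get_text(indices))                # indices -> text
    print(mapping[2])                               # __getitem__: int -> token text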
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 8ce5c37..57c5964 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -11,6 +11,8 @@ from torch import nn
from torch import Tensor
import torchmetrics
+from text_recognizer.data.base_mapping import AbstractMapping
+
@attr.s(eq=False)
class BaseLitModel(LightningModule):
@@ -20,12 +22,12 @@ class BaseLitModel(LightningModule):
super().__init__()
network: Type[nn.Module] = attr.ib()
- criterion_config: DictConfig = attr.ib(converter=DictConfig)
- optimizer_config: DictConfig = attr.ib(converter=DictConfig)
- lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig)
+ mapping: Type[AbstractMapping] = attr.ib()
+ loss_fn: Type[nn.Module] = attr.ib()
+ optimizer_config: DictConfig = attr.ib()
+ lr_scheduler_config: DictConfig = attr.ib()
interval: str = attr.ib()
monitor: str = attr.ib(default="val/loss")
- loss_fn: Type[nn.Module] = attr.ib(init=False)
train_acc: torchmetrics.Accuracy = attr.ib(
init=False, default=torchmetrics.Accuracy()
)
@@ -36,12 +38,6 @@ class BaseLitModel(LightningModule):
init=False, default=torchmetrics.Accuracy()
)
- @loss_fn.default
- def configure_criterion(self) -> Type[nn.Module]:
- """Returns a loss functions."""
- log.info(f"Instantiating criterion <{self.criterion_config._target_}>")
- return hydra.utils.instantiate(self.criterion_config)
-
def optimizer_zero_grad(
self,
epoch: int,
@@ -54,7 +50,9 @@ class BaseLitModel(LightningModule):
def _configure_optimizer(self) -> Type[torch.optim.Optimizer]:
"""Configures the optimizer."""
log.info(f"Instantiating optimizer <{self.optimizer_config._target_}>")
- return hydra.utils.instantiate(self.optimizer_config, params=self.parameters())
+ return hydra.utils.instantiate(
+ self.optimizer_config, params=self.network.parameters()
+ )
def _configure_lr_scheduler(
self, optimizer: Type[torch.optim.Optimizer]
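
For context, the instantiate-from-config pattern used above looks like this in isolation (a sketch with illustrative config values, not the repository's actual optimizer config):

    import hydra
    from omegaconf import DictConfig
    from torch import nn

    network = nn.Linear(10, 10)
    optimizer_config = DictConfig({"_target_": "torch.optim.Adam", "lr": 1e-3})
    # Extra keyword arguments are merged with the config before the target is built.
    optimizer = hydra.utils.instantiate(optimizer_config, params=network.parameters())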
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 91e088d..5fb84a7 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -5,7 +5,6 @@ import attr
import torch
from torch import Tensor
-from text_recognizer.data.mappings import AbstractMapping
from text_recognizer.models.metrics import CharacterErrorRate
from text_recognizer.models.base import BaseLitModel
@@ -14,14 +13,14 @@ from text_recognizer.models.base import BaseLitModel
class TransformerLitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
- mapping: Type[AbstractMapping] = attr.ib(default=None)
+ max_output_len: int = attr.ib(default=451)
start_token: str = attr.ib(default="<s>")
end_token: str = attr.ib(default="<e>")
pad_token: str = attr.ib(default="<p>")
- start_index: Tensor = attr.ib(init=False)
- end_index: Tensor = attr.ib(init=False)
- pad_index: Tensor = attr.ib(init=False)
+ start_index: int = attr.ib(init=False)
+ end_index: int = attr.ib(init=False)
+ pad_index: int = attr.ib(init=False)
ignore_indices: Set[Tensor] = attr.ib(init=False)
val_cer: CharacterErrorRate = attr.ib(init=False)
@@ -29,9 +28,9 @@ class TransformerLitModel(BaseLitModel):
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
- self.start_index = self.mapping.get_index(self.start_token)
- self.end_index = self.mapping.get_index(self.end_token)
- self.pad_index = self.mapping.get_index(self.pad_token)
+ self.start_index = int(self.mapping.get_index(self.start_token))
+ self.end_index = int(self.mapping.get_index(self.end_token))
+ self.pad_index = int(self.mapping.get_index(self.pad_token))
self.ignore_indices = set([self.start_index, self.end_index, self.pad_index])
self.val_cer = CharacterErrorRate(self.ignore_indices)
self.test_cer = CharacterErrorRate(self.ignore_indices)
@@ -93,23 +92,24 @@ class TransformerLitModel(BaseLitModel):
output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
output[:, 0] = self.start_index
- for i in range(1, self.max_output_len):
- context = output[:, :i] # (bsz, i)
- logits = self.network.decode(z, context) # (i, bsz, c)
- tokens = torch.argmax(logits, dim=-1) # (i, bsz)
- output[:, i : i + 1] = tokens[-1:]
+ for Sy in range(1, self.max_output_len):
+ context = output[:, :Sy] # (B, Sy)
+ logits = self.network.decode(z, context) # (B, Sy, C)
+ tokens = torch.argmax(logits, dim=-1) # (B, Sy)
+ output[:, Sy : Sy + 1] = tokens[:, -1:]
# Early stopping of prediction loop if token is end or padding token.
if (
- output[:, i - 1] == self.end_index | output[: i - 1] == self.pad_index
+ (output[:, Sy - 1] == self.end_index)
+ | (output[:, Sy - 1] == self.pad_index)
).all():
break
# Set all tokens after end token to pad token.
- for i in range(1, self.max_output_len):
- idx = (
- output[:, i - 1] == self.end_index | output[:, i - 1] == self.pad_index
+ for Sy in range(1, self.max_output_len):
+ idx = (output[:, Sy - 1] == self.end_index) | (
+ output[:, Sy - 1] == self.pad_index
)
- output[idx, i] = self.pad_index
+ output[idx, Sy] = self.pad_index
return output
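
The corrected post-processing of the prediction loop can be illustrated in isolation (a standalone sketch, not repository code; token indices are illustrative):

    import torch

    end_index, pad_index = 3, 4
    output = torch.tensor([[1, 7, 9, 3, 8, 6],
                           [1, 5, 3, 4, 4, 4]])

    # Same pattern as the fixed loop above: once a row has emitted the end (or pad)
    # token, every later position is overwritten with the pad index.
    for Sy in range(1, output.shape[1]):
        idx = (output[:, Sy - 1] == end_index) | (output[:, Sy - 1] == pad_index)
        output[idx, Sy] = pad_index

    print(output)
    # tensor([[1, 7, 9, 3, 4, 4],
    #         [1, 5, 3, 4, 4, 4]])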
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index 09cc654..f3ba49d 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -2,7 +2,6 @@
import math
from typing import Tuple
-import attr
from torch import nn, Tensor
from text_recognizer.networks.encoders.efficientnet import EfficientNet
@@ -13,32 +12,28 @@ from text_recognizer.networks.transformer.positional_encodings import (
)
-@attr.s(eq=False)
class ConvTransformer(nn.Module):
"""Convolutional encoder and transformer decoder network."""
- def __attrs_pre_init__(self) -> None:
+ def __init__(
+ self,
+ input_dims: Tuple[int, int, int],
+ hidden_dim: int,
+ dropout_rate: float,
+ num_classes: int,
+ pad_index: Tensor,
+ encoder: EfficientNet,
+ decoder: Decoder,
+ ) -> None:
super().__init__()
+ self.input_dims = input_dims
+ self.hidden_dim = hidden_dim
+ self.dropout_rate = dropout_rate
+ self.num_classes = num_classes
+ self.pad_index = pad_index
+ self.encoder = encoder
+ self.decoder = decoder
- # Parameters and placeholders,
- input_dims: Tuple[int, int, int] = attr.ib()
- hidden_dim: int = attr.ib()
- dropout_rate: float = attr.ib()
- max_output_len: int = attr.ib()
- num_classes: int = attr.ib()
- pad_index: Tensor = attr.ib()
-
- # Modules.
- encoder: EfficientNet = attr.ib()
- decoder: Decoder = attr.ib()
-
- latent_encoder: nn.Sequential = attr.ib(init=False)
- token_embedding: nn.Embedding = attr.ib(init=False)
- token_pos_encoder: PositionalEncoding = attr.ib(init=False)
- head: nn.Linear = attr.ib(init=False)
-
- def __attrs_post_init__(self) -> None:
- """Post init configuration."""
# Latent projector for down sampling number of filters and 2d
# positional encoding.
self.latent_encoder = nn.Sequential(
@@ -126,7 +121,8 @@ class ConvTransformer(nn.Module):
context = self.token_embedding(context) * math.sqrt(self.hidden_dim)
context = self.token_pos_encoder(context)
out = self.decoder(x=context, context=z, mask=context_mask)
- logits = self.head(out)
+ logits = self.head(out) # [B, Sy, T]
+ logits = logits.permute(0, 2, 1) # [B, T, Sy]
return logits
def forward(self, x: Tensor, context: Tensor) -> Tensor:
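
The added permute puts class scores in dimension 1, the layout torch.nn.CrossEntropyLoss expects when targets have shape [B, Sy]; a standalone sketch (shapes are illustrative):

    import torch
    from torch import nn

    B, Sy, num_classes = 4, 10, 58
    logits = torch.randn(B, Sy, num_classes)          # decoder output, [B, Sy, C]
    targets = torch.randint(0, num_classes, (B, Sy))  # [B, Sy]

    loss = nn.CrossEntropyLoss()(logits.permute(0, 2, 1), targets)  # expects [B, C, Sy]
    print(loss.item())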
diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv.py b/text_recognizer/networks/encoders/efficientnet/mbconv.py
index e85df87..7bfd9ba 100644
--- a/text_recognizer/networks/encoders/efficientnet/mbconv.py
+++ b/text_recognizer/networks/encoders/efficientnet/mbconv.py
@@ -11,9 +11,7 @@ from text_recognizer.networks.encoders.efficientnet.utils import stochastic_dept
def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]:
"""Converts int to tuple."""
- return (
- (stride,) * 2 if isinstance(stride, int) else stride
- )
+ return (stride,) * 2 if isinstance(stride, int) else stride
@attr.s(eq=False)
@@ -41,10 +39,7 @@ class MBConvBlock(nn.Module):
def _configure_padding(self) -> Tuple[int, int, int, int]:
"""Set padding for convolutional layers."""
if self.stride == (2, 2):
- return (
- (self.kernel_size - 1) // 2 - 1,
- (self.kernel_size - 1) // 2,
- ) * 2
+ return ((self.kernel_size - 1) // 2 - 1, (self.kernel_size - 1) // 2,) * 2
return ((self.kernel_size - 1) // 2,) * 4
def __attrs_post_init__(self) -> None:
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index ce443e5..70a0ac7 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -1,5 +1,4 @@
"""Transformer attention layer."""
-from functools import partial
from typing import Any, Dict, Optional, Tuple
import attr
@@ -27,25 +26,17 @@ class AttentionLayers(nn.Module):
norm_fn: str = attr.ib()
ff_fn: str = attr.ib()
ff_kwargs: Dict = attr.ib()
+ rotary_emb: Optional[RotaryEmbedding] = attr.ib()
causal: bool = attr.ib(default=False)
cross_attend: bool = attr.ib(default=False)
pre_norm: bool = attr.ib(default=True)
- rotary_emb: Optional[RotaryEmbedding] = attr.ib(default=None)
layer_types: Tuple[str, ...] = attr.ib(init=False)
layers: nn.ModuleList = attr.ib(init=False)
- attn: partial = attr.ib(init=False)
- norm: partial = attr.ib(init=False)
- ff: partial = attr.ib(init=False)
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
self.layer_types = self._get_layer_types() * self.depth
- attn = load_partial_fn(
- self.attn_fn, dim=self.dim, num_heads=self.num_heads, **self.attn_kwargs
- )
- norm = load_partial_fn(self.norm_fn, normalized_shape=self.dim)
- ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs)
- self.layers = self._build_network(attn, norm, ff)
+ self.layers = self._build_network()
def _get_layer_types(self) -> Tuple:
"""Get layer specification."""
@@ -53,10 +44,13 @@ class AttentionLayers(nn.Module):
return "a", "c", "f"
return "a", "f"
- def _build_network(
- self, attn: partial, norm: partial, ff: partial,
- ) -> nn.ModuleList:
+ def _build_network(self) -> nn.ModuleList:
"""Configures transformer network."""
+ attn = load_partial_fn(
+ self.attn_fn, dim=self.dim, num_heads=self.num_heads, **self.attn_kwargs
+ )
+ norm = load_partial_fn(self.norm_fn, normalized_shape=self.dim)
+ ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs)
layers = nn.ModuleList([])
for layer_type in self.layer_types:
if layer_type == "a":
@@ -106,6 +100,7 @@ class Encoder(AttentionLayers):
causal: bool = attr.ib(default=False, init=False)
-@attr.s(auto_attribs=True, eq=False)
class Decoder(AttentionLayers):
- causal: bool = attr.ib(default=True, init=False)
+ def __init__(self, **kwargs: Any) -> None:
+ assert "causal" not in kwargs, "Cannot set causality on decoder"
+ super().__init__(causal=True, **kwargs)