path: root/text_recognizer
author    Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-07-30 23:15:03 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-07-30 23:15:03 +0200
commit    7268035fb9e57342612a8cc50a1fe04e8841ca2f (patch)
tree      8d4cf3743975bd25f2c04d6a56ff3d4608a7e8d9 /text_recognizer
parent    92fc1c7ed2f9f64552be8f71d9b8ab0d5a0a88d4 (diff)
attr bug fix: properly load the network
Diffstat (limited to 'text_recognizer')
-rw-r--r--  text_recognizer/data/__init__.py                   |   6
-rw-r--r--  text_recognizer/data/base_data_module.py           |   2
-rw-r--r--  text_recognizer/data/emnist_lines.py               |   2
-rw-r--r--  text_recognizer/data/iam_extended_paragraphs.py    |   2
-rw-r--r--  text_recognizer/data/iam_lines.py                  |   2
-rw-r--r--  text_recognizer/data/iam_paragraphs.py             |   2
-rw-r--r--  text_recognizer/data/iam_preprocessor.py           |   8
-rw-r--r--  text_recognizer/data/iam_synthetic_paragraphs.py   |   2
-rw-r--r--  text_recognizer/data/mappings.py                   | 111
-rw-r--r--  text_recognizer/data/transforms.py                 |  16
-rw-r--r--  text_recognizer/models/base.py                     |   4
-rw-r--r--  text_recognizer/models/metrics.py                  |   8
-rw-r--r--  text_recognizer/models/transformer.py              |  80
-rw-r--r--  text_recognizer/models/vqvae.py                    |   5
-rw-r--r--  text_recognizer/networks/__init__.py               |   4
-rw-r--r--  text_recognizer/networks/base.py                   |  18
-rw-r--r--  text_recognizer/networks/conv_transformer.py       |  69
-rw-r--r--  text_recognizer/networks/transformer/attention.py  |   2
-rw-r--r--  text_recognizer/networks/transformer/layers.py     |  16
-rw-r--r--  text_recognizer/networks/transformer/norm.py       |   8
-rw-r--r--  text_recognizer/networks/util.py                   |   4
-rw-r--r--  text_recognizer/networks/vqvae/vqvae.py            |   1
22 files changed, 165 insertions, 207 deletions
diff --git a/text_recognizer/data/__init__.py b/text_recognizer/data/__init__.py
index 3599a8b..2727b20 100644
--- a/text_recognizer/data/__init__.py
+++ b/text_recognizer/data/__init__.py
@@ -1,7 +1 @@
"""Dataset modules."""
-from .base_dataset import BaseDataset, convert_strings_to_labels, split_dataset
-from .base_data_module import BaseDataModule, load_and_print_info
-from .download_utils import download_dataset
-from .iam_paragraphs import IAMParagraphs
-from .iam_synthetic_paragraphs import IAMSyntheticParagraphs
-from .iam_extended_paragraphs import IAMExtendedParagraphs
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index 18b1996..408ae36 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -17,7 +17,7 @@ def load_and_print_info(data_module_class: type) -> None:
print(dataset)
-@attr.s
+@attr.s(repr=False)
class BaseDataModule(LightningDataModule):
"""Base PyTorch Lightning DataModule."""
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 4747508..7548ad5 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -32,7 +32,7 @@ IMAGE_X_PADDING = 28
MAX_OUTPUT_LENGTH = 89 # Same as IAMLines
-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, repr=False)
class EMNISTLines(BaseDataModule):
"""EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST,"""
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 58c7369..23e424d 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -10,7 +10,7 @@ from text_recognizer.data.iam_paragraphs import IAMParagraphs
from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs
-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, repr=False)
class IAMExtendedParagraphs(BaseDataModule):
augment: bool = attr.ib(default=True)
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index 13dd379..b7f3fdd 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -37,7 +37,7 @@ IMAGE_WIDTH = 1024
MAX_LABEL_LENGTH = 89
-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, repr=False)
class IAMLines(BaseDataModule):
"""IAM handwritten lines dataset."""
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index de32875..82058e0 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -34,7 +34,7 @@ IMAGE_WIDTH = 1280 // IMAGE_SCALE_FACTOR
MAX_LABEL_LENGTH = 682
-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, repr=False)
class IAMParagraphs(BaseDataModule):
"""IAM handwriting database paragraphs."""
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py
index f7457e4..93a13bb 100644
--- a/text_recognizer/data/iam_preprocessor.py
+++ b/text_recognizer/data/iam_preprocessor.py
@@ -9,7 +9,7 @@ import collections
import itertools
from pathlib import Path
import re
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Sequence
import click
from loguru import logger
@@ -57,15 +57,13 @@ class Preprocessor:
lexicon_path: Optional[Union[str, Path]] = None,
use_words: bool = False,
prepend_wordsep: bool = False,
- special_tokens: Optional[List[str]] = None,
+ special_tokens: Optional[Sequence[str]] = None,
) -> None:
self.wordsep = "▁"
self._use_word = use_words
self._prepend_wordsep = prepend_wordsep
self.special_tokens = special_tokens if special_tokens is not None else None
-
self.data_dir = Path(data_dir)
-
self.forms = load_metadata(self.data_dir, self.wordsep, use_words=use_words)
# Load the set of graphemes:
@@ -123,7 +121,7 @@ class Preprocessor:
self.text.append(example["text"].lower())
def _to_index(self, line: str) -> torch.LongTensor:
- if line in self.special_tokens:
+ if self.special_tokens is not None and line in self.special_tokens:
return torch.LongTensor([self.tokens_to_index[line]])
token_to_index = self.graphemes_to_index
if self.lexicon is not None:
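
The added None-guard in _to_index matters because special_tokens may now be None; a bare `line in special_tokens` would raise a TypeError in that case. A small sketch of the pattern with hypothetical values:

    special_tokens = None  # hypothetical: no special tokens configured
    line = "<s>"

    # The `is not None` check short-circuits before the membership test,
    # which would otherwise raise TypeError on None.
    if special_tokens is not None and line in special_tokens:
        print("special token")
    else:
        print("regular text")  # printed in this sketch
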
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index a3697e7..f00a494 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -34,7 +34,7 @@ PROCESSED_DATA_DIRNAME = (
)
-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, repr=False)
class IAMSyntheticParagraphs(IAMParagraphs):
"""IAM Handwriting database of synthetic paragraphs."""
diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py
index 0d778b2..a934fd9 100644
--- a/text_recognizer/data/mappings.py
+++ b/text_recognizer/data/mappings.py
@@ -1,8 +1,9 @@
"""Mapping to and from word pieces."""
from abc import ABC, abstractmethod
from pathlib import Path
-from typing import List, Optional, Union, Sequence
+from typing import Dict, List, Optional, Union, Set, Sequence
+import attr
from loguru import logger
import torch
from torch import Tensor
@@ -29,10 +30,17 @@ class AbstractMapping(ABC):
...
+@attr.s
class EmnistMapping(AbstractMapping):
- def __init__(self, extra_symbols: Optional[Sequence[str]]) -> None:
+ extra_symbols: Optional[Set[str]] = attr.ib(default=None, converter=set)
+ mapping: Sequence[str] = attr.ib(init=False)
+ inverse_mapping: Dict[str, int] = attr.ib(init=False)
+ input_size: List[int] = attr.ib(init=False)
+
+ def __attrs_post_init__(self) -> None:
+ """Post init configuration."""
self.mapping, self.inverse_mapping, self.input_size = emnist_mapping(
- extra_symbols
+ self.extra_symbols
)
def get_token(self, index: Union[int, Tensor]) -> str:
@@ -54,42 +62,21 @@ class EmnistMapping(AbstractMapping):
return Tensor([self.inverse_mapping[token] for token in text])
+@attr.s(auto_attribs=True)
class WordPieceMapping(EmnistMapping):
- def __init__(
- self,
- num_features: int = 1000,
- tokens: str = "iamdb_1kwp_tokens_1000.txt",
- lexicon: str = "iamdb_1kwp_lex_1000.txt",
- data_dir: Optional[Union[str, Path]] = None,
- use_words: bool = False,
- prepend_wordsep: bool = False,
- special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
- extra_symbols: Optional[Sequence[str]] = ("\n",),
- ) -> None:
- super().__init__(extra_symbols)
- self.wordpiece_processor = self._configure_wordpiece_processor(
- num_features,
- tokens,
- lexicon,
- data_dir,
- use_words,
- prepend_wordsep,
- special_tokens,
- extra_symbols,
- )
-
- @staticmethod
- def _configure_wordpiece_processor(
- num_features: int,
- tokens: str,
- lexicon: str,
- data_dir: Optional[Union[str, Path]],
- use_words: bool,
- prepend_wordsep: bool,
- special_tokens: Optional[Sequence[str]],
- extra_symbols: Optional[Sequence[str]],
- ) -> Preprocessor:
- data_dir = (
+ data_dir: Optional[Path] = attr.ib(default=None)
+ num_features: int = attr.ib(default=1000)
+ tokens: str = attr.ib(default="iamdb_1kwp_tokens_1000.txt")
+ lexicon: str = attr.ib(default="iamdb_1kwp_lex_1000.txt")
+ use_words: bool = attr.ib(default=False)
+ prepend_wordsep: bool = attr.ib(default=False)
+ special_tokens: Set[str] = attr.ib(default={"<s>", "<e>", "<p>"}, converter=set)
+ extra_symbols: Set[str] = attr.ib(default={"\n",}, converter=set)
+ wordpiece_processor: Preprocessor = attr.ib(init=False)
+
+ def __attrs_post_init__(self) -> None:
+ super().__attrs_post_init__()
+ self.data_dir = (
(
Path(__file__).resolve().parents[2]
/ "data"
@@ -97,32 +84,32 @@ class WordPieceMapping(EmnistMapping):
/ "iam"
/ "iamdb"
)
- if data_dir is None
- else Path(data_dir)
+ if self.data_dir is None
+ else Path(self.data_dir)
)
-
- logger.debug(f"Using data dir: {data_dir}")
- if not data_dir.exists():
- raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
+ logger.debug(f"Using data dir: {self.data_dir}")
+ if not self.data_dir.exists():
+ raise RuntimeError(f"Could not locate iamdb directory at {self.data_dir}")
processed_path = (
Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines"
)
- tokens_path = processed_path / tokens
- lexicon_path = processed_path / lexicon
-
- if extra_symbols is not None:
- special_tokens += extra_symbols
-
- return Preprocessor(
- data_dir,
- num_features,
- tokens_path,
- lexicon_path,
- use_words,
- prepend_wordsep,
- special_tokens,
+ tokens_path = processed_path / self.tokens
+ lexicon_path = processed_path / self.lexicon
+
+ special_tokens = self.special_tokens
+ if self.extra_symbols is not None:
+ special_tokens = special_tokens | self.extra_symbols
+
+ self.wordpiece_processor = Preprocessor(
+ data_dir=self.data_dir,
+ num_features=self.num_features,
+ tokens_path=tokens_path,
+ lexicon_path=lexicon_path,
+ use_words=self.use_words,
+ prepend_wordsep=self.prepend_wordsep,
+ special_tokens=special_tokens,
)
def __len__(self) -> int:
@@ -151,7 +138,9 @@ class WordPieceMapping(EmnistMapping):
text = text.lower().replace(" ", "▁")
return torch.LongTensor(self.wordpiece_processor.to_index(text))
- def __getitem__(self, x: Union[str, int, Tensor]) -> Union[str, Tensor]:
+ def __getitem__(self, x: Union[str, int, List[int], Tensor]) -> Union[str, Tensor]:
+ if isinstance(x, int):
+ x = [x]
if isinstance(x, str):
- return self.get_index(x)
- return self.get_token(x)
+ return self.get_indices(x)
+ return self.get_text(x)
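
The conversion of EmnistMapping and WordPieceMapping to attrs follows one pattern: former constructor arguments become attr.ib() fields, while derived state is declared with init=False and filled in inside __attrs_post_init__, which subclasses chain via super().__attrs_post_init__(). A self-contained sketch of that pattern with illustrative names:

    import attr

    @attr.s(auto_attribs=True)
    class Base:
        base: int = attr.ib(default=2)
        derived: int = attr.ib(init=False)

        def __attrs_post_init__(self) -> None:
            # Derived state computed after attrs assigns `base`.
            self.derived = self.base ** 2

    @attr.s(auto_attribs=True)
    class Child(Base):
        extra: int = attr.ib(default=1)

        def __attrs_post_init__(self) -> None:
            super().__attrs_post_init__()  # keep the parent's derived state
            self.derived += self.extra

    print(Child(base=3, extra=4).derived)  # 13
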
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
index 66531a5..3b1b929 100644
--- a/text_recognizer/data/transforms.py
+++ b/text_recognizer/data/transforms.py
@@ -24,14 +24,14 @@ class WordPiece:
max_len: int = 451,
) -> None:
self.mapping = WordPieceMapping(
- num_features,
- tokens,
- lexicon,
- data_dir,
- use_words,
- prepend_wordsep,
- special_tokens,
- extra_symbols,
+ data_dir=data_dir,
+ num_features=num_features,
+ tokens=tokens,
+ lexicon=lexicon,
+ use_words=use_words,
+ prepend_wordsep=prepend_wordsep,
+ special_tokens=special_tokens,
+ extra_symbols=extra_symbols,
)
self.max_len = max_len
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 3e02261..dfb4ca4 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -11,8 +11,6 @@ from torch import nn
from torch import Tensor
import torchmetrics
-from text_recognizer.networks.base import BaseNetwork
-
@attr.s
class BaseLitModel(LightningModule):
@@ -21,7 +19,7 @@ class BaseLitModel(LightningModule):
def __attrs_pre_init__(self) -> None:
super().__init__()
- network: Type[BaseNetwork] = attr.ib()
+ network: Type[nn.Module] = attr.ib()
criterion_config: DictConfig = attr.ib(converter=DictConfig)
optimizer_config: DictConfig = attr.ib(converter=DictConfig)
lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig)
diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py
index 4117ae2..9793157 100644
--- a/text_recognizer/models/metrics.py
+++ b/text_recognizer/models/metrics.py
@@ -1,5 +1,5 @@
"""Character Error Rate (CER)."""
-from typing import Set, Sequence
+from typing import Set
import attr
import editdistance
@@ -12,7 +12,7 @@ from torchmetrics import Metric
class CharacterErrorRate(Metric):
"""Character error rate metric, computed using Levenshtein distance."""
- ignore_tokens: Set = attr.ib(converter=set)
+ ignore_indices: Set = attr.ib(converter=set)
error: Tensor = attr.ib(init=False)
total: Tensor = attr.ib(init=False)
@@ -25,8 +25,8 @@ class CharacterErrorRate(Metric):
"""Update CER."""
bsz = preds.shape[0]
for index in range(bsz):
- pred = [p for p in preds[index].tolist() if p not in self.ignore_tokens]
- target = [t for t in targets[index].tolist() if t not in self.ignore_tokens]
+ pred = [p for p in preds[index].tolist() if p not in self.ignore_indices]
+ target = [t for t in targets[index].tolist() if t not in self.ignore_indices]
distance = editdistance.distance(pred, target)
error = distance / max(len(pred), len(target))
self.error += error
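
The rename from ignore_tokens to ignore_indices reflects that the metric filters integer token indices, not token strings, before computing the Levenshtein distance. A minimal sketch of the per-sample computation, with illustrative index values standing in for start/end/pad:

    import editdistance

    ignore_indices = {0, 1, 2}          # hypothetical pad/start/end indices
    preds = [1, 7, 8, 9, 2, 0, 0]       # hypothetical decoded token ids
    targets = [1, 7, 8, 5, 2, 0, 0]

    pred = [p for p in preds if p not in ignore_indices]
    target = [t for t in targets if t not in ignore_indices]
    error = editdistance.distance(pred, target) / max(len(pred), len(target))
    print(error)  # one substitution over three tokens ~= 0.33
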
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index f5cb491..7a9d566 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,11 +1,11 @@
"""PyTorch Lightning model for base Transformers."""
-from typing import Dict, List, Optional, Sequence, Union, Tuple, Type
+from typing import Sequence, Tuple, Type
import attr
-import hydra
-from omegaconf import DictConfig
-from torch import nn, Tensor
+import torch
+from torch import Tensor
+from text_recognizer.data.mappings import AbstractMapping
from text_recognizer.models.metrics import CharacterErrorRate
from text_recognizer.models.base import BaseLitModel
@@ -13,18 +13,31 @@ from text_recognizer.models.base import BaseLitModel
@attr.s(auto_attribs=True)
class TransformerLitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
+ mapping: Type[AbstractMapping] = attr.ib()
+ start_token: str = attr.ib()
+ end_token: str = attr.ib()
+ pad_token: str = attr.ib()
- ignore_tokens: Sequence[str] = attr.ib(default=("<s>", "<e>", "<p>",))
+ start_index: Tensor = attr.ib(init=False)
+ end_index: Tensor = attr.ib(init=False)
+ pad_index: Tensor = attr.ib(init=False)
+
+ ignore_indices: Sequence[str] = attr.ib(init=False)
val_cer: CharacterErrorRate = attr.ib(init=False)
test_cer: CharacterErrorRate = attr.ib(init=False)
def __attrs_post_init__(self) -> None:
- self.val_cer = CharacterErrorRate(self.ignore_tokens)
- self.test_cer = CharacterErrorRate(self.ignore_tokens)
+ """Post init configuration."""
+ self.start_index = self.mapping.get_index(self.start_token)
+ self.end_index = self.mapping.get_index(self.end_token)
+ self.pad_index = self.mapping.get_index(self.pad_token)
+ self.ignore_indices = [self.start_index, self.end_index, self.pad_index]
+ self.val_cer = CharacterErrorRate(self.ignore_indices)
+ self.test_cer = CharacterErrorRate(self.ignore_indices)
def forward(self, data: Tensor) -> Tensor:
"""Forward pass with the transformer network."""
- return self.network.predict(data)
+ return self.predict(data)
def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
"""Training step."""
@@ -38,17 +51,64 @@ class TransformerLitModel(BaseLitModel):
"""Validation step."""
data, targets = batch
+ # Compute the loss.
logits = self.network(data, targets[:-1])
loss = self.loss_fn(logits, targets[1:])
self.log("val/loss", loss, prog_bar=True)
- pred = self.network.predict(data)
+ # Get the token prediction.
+ pred = self(data)
self.val_cer(pred, targets)
self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
"""Test step."""
data, targets = batch
- pred = self.network.predict(data)
+
+ # Compute the text prediction.
+ pred = self(data)
self.test_cer(pred, targets)
self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
+
+ def predict(self, x: Tensor) -> Tensor:
+ """Predicts text in image.
+
+ Args:
+ x (Tensor): Image(s) to extract text from.
+
+ Shapes:
+ - x: :math: `(B, H, W)`
+ - output: :math: `(B, S)`
+
+ Returns:
+ Tensor: A tensor of token indices of the predictions from the model.
+ """
+ bsz = x.shape[0]
+
+ # Encode image(s) to latent vectors.
+ z = self.network.encode(x)
+
+ # Create a placeholder matrix for storing outputs from the network
+ output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
+ output[:, 0] = self.start_index
+
+ for i in range(1, self.max_output_len):
+ context = output[:, :i] # (bsz, i)
+ logits = self.network.decode(z, context) # (i, bsz, c)
+ tokens = torch.argmax(logits, dim=-1) # (i, bsz)
+ output[:, i : i + 1] = tokens[-1:]
+
+ # Early stopping of prediction loop if token is end or padding token.
+ if (
+ output[:, i - 1] == self.end_index | output[: i - 1] == self.pad_index
+ ).all():
+ break
+
+ # Set all tokens after end token to pad token.
+ for i in range(1, self.max_output_len):
+ idx = (
+ output[:, i - 1] == self.end_index | output[:, i - 1] == self.pad_index
+ )
+ output[idx, i] = self.pad_index
+
+ return output
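
One caveat worth flagging in the predict loop (both here and in the ConvTransformer.predict it replaces): in Python `|` binds tighter than `==`, so the early-stopping test as written parses as a chained comparison rather than an OR of two boolean masks, and its first operand slices rows (`output[: i - 1]`) instead of the column (`output[:, i - 1]`). A self-contained sketch of what the check presumably intends, with hypothetical token ids:

    import torch

    end_index, pad_index = 2, 3          # hypothetical token indices
    output = torch.tensor([[1, 5, 2],    # hypothetical decoded batch
                           [1, 6, 3]])
    i = 3

    # Parentheses are required so `|` combines two boolean masks element-wise.
    ended = (output[:, i - 1] == end_index) | (output[:, i - 1] == pad_index)
    if ended.all():
        print("stop decoding")  # printed: both rows end with <e> or <p>
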
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
index 0172163..e215e14 100644
--- a/text_recognizer/models/vqvae.py
+++ b/text_recognizer/models/vqvae.py
@@ -34,8 +34,6 @@ class VQVAELitModel(BaseLitModel):
loss = self.loss_fn(reconstructions, data)
loss += vq_loss
self.log("val/loss", loss, prog_bar=True)
- title = "val_pred_examples"
- self._log_prediction(data, reconstructions, title)
def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
"""Test step."""
@@ -43,5 +41,4 @@ class VQVAELitModel(BaseLitModel):
reconstructions, vq_loss = self.network(data)
loss = self.loss_fn(reconstructions, data)
loss += vq_loss
- title = "test_pred_examples"
- self._log_prediction(data, reconstructions, title)
+ self.log("test/loss", loss)
diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py
index 618450f..d9ef58b 100644
--- a/text_recognizer/networks/__init__.py
+++ b/text_recognizer/networks/__init__.py
@@ -1,5 +1 @@
"""Network modules"""
-# from .encoders import EfficientNet
-from .vqvae import VQVAE
-
-# from .cnn_transformer import CNNTransformer
diff --git a/text_recognizer/networks/base.py b/text_recognizer/networks/base.py
deleted file mode 100644
index 07b6a32..0000000
--- a/text_recognizer/networks/base.py
+++ /dev/null
@@ -1,18 +0,0 @@
-"""Base network with required methods."""
-from abc import abstractmethod
-
-import attr
-from torch import nn, Tensor
-
-
-@attr.s
-class BaseNetwork(nn.Module):
- """Base network."""
-
- def __attrs_pre_init__(self) -> None:
- super().__init__()
-
- @abstractmethod
- def predict(self, x: Tensor) -> Tensor:
- """Return token indices for predictions."""
- ...
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index 4acdc36..7371be4 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -1,13 +1,10 @@
"""Vision transformer for character recognition."""
import math
-from typing import Tuple, Type
+from typing import Tuple
import attr
-import torch
from torch import nn, Tensor
-from text_recognizer.data.mappings import AbstractMapping
-from text_recognizer.networks.base import BaseNetwork
from text_recognizer.networks.encoders.efficientnet import EfficientNet
from text_recognizer.networks.transformer.layers import Decoder
from text_recognizer.networks.transformer.positional_encodings import (
@@ -16,25 +13,24 @@ from text_recognizer.networks.transformer.positional_encodings import (
)
-@attr.s(auto_attribs=True)
-class ConvTransformer(BaseNetwork):
+@attr.s
+class ConvTransformer(nn.Module):
+ """Convolutional encoder and transformer decoder network."""
+
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
+
# Parameters and placeholders,
input_dims: Tuple[int, int, int] = attr.ib()
hidden_dim: int = attr.ib()
dropout_rate: float = attr.ib()
max_output_len: int = attr.ib()
num_classes: int = attr.ib()
- start_token: str = attr.ib()
- start_index: Tensor = attr.ib(init=False)
- end_token: str = attr.ib()
- end_index: Tensor = attr.ib(init=False)
- pad_token: str = attr.ib()
- pad_index: Tensor = attr.ib(init=False)
+ pad_index: Tensor = attr.ib()
# Modules.
encoder: EfficientNet = attr.ib()
decoder: Decoder = attr.ib()
- mapping: Type[AbstractMapping] = attr.ib()
latent_encoder: nn.Sequential = attr.ib(init=False)
token_embedding: nn.Embedding = attr.ib(init=False)
@@ -43,10 +39,6 @@ class ConvTransformer(BaseNetwork):
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
- self.start_index = self.mapping.get_index(self.start_token)
- self.end_index = self.mapping.get_index(self.end_token)
- self.pad_index = self.mapping.get_index(self.pad_token)
-
# Latent projector for down sampling number of filters and 2d
# positional encoding.
self.latent_encoder = nn.Sequential(
@@ -156,46 +148,3 @@ class ConvTransformer(BaseNetwork):
z = self.encode(x)
logits = self.decode(z, context)
return logits
-
- def predict(self, x: Tensor) -> Tensor:
- """Predicts text in image.
-
- Args:
- x (Tensor): Image(s) to extract text from.
-
- Shapes:
- - x: :math: `(B, H, W)`
- - output: :math: `(B, S)`
-
- Returns:
- Tensor: A tensor of token indices of the predictions from the model.
- """
- bsz = x.shape[0]
-
- # Encode image(s) to latent vectors.
- z = self.encode(x)
-
- # Create a placeholder matrix for storing outputs from the network
- output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
- output[:, 0] = self.start_index
-
- for i in range(1, self.max_output_len):
- context = output[:, :i] # (bsz, i)
- logits = self.decode(z, context) # (i, bsz, c)
- tokens = torch.argmax(logits, dim=-1) # (i, bsz)
- output[:, i : i + 1] = tokens[-1:]
-
- # Early stopping of prediction loop if token is end or padding token.
- if (
- output[:, i - 1] == self.end_index | output[: i - 1] == self.pad_index
- ).all():
- break
-
- # Set all tokens after end token to pad token.
- for i in range(1, self.max_output_len):
- idx = (
- output[:, i - 1] == self.end_index | output[:, i - 1] == self.pad_index
- )
- output[idx, i] = self.pad_index
-
- return output
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 2770dc1..9202cce 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -24,9 +24,9 @@ class Attention(nn.Module):
dim: int = attr.ib()
num_heads: int = attr.ib()
+ causal: bool = attr.ib(default=False)
dim_head: int = attr.ib(default=64)
dropout_rate: float = attr.ib(default=0.0)
- casual: bool = attr.ib(default=False)
scale: float = attr.ib(init=False)
dropout: nn.Dropout = attr.ib(init=False)
fc: nn.Linear = attr.ib(init=False)
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index 9b2f236..66c9c50 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -30,8 +30,7 @@ class AttentionLayers(nn.Module):
causal: bool = attr.ib(default=False)
cross_attend: bool = attr.ib(default=False)
pre_norm: bool = attr.ib(default=True)
- rotary_emb: Optional[RotaryEmbedding] = attr.ib(default=None, init=False)
- has_pos_emb: bool = attr.ib(init=False)
+ rotary_emb: Optional[RotaryEmbedding] = attr.ib(default=None)
layer_types: Tuple[str, ...] = attr.ib(init=False)
layers: nn.ModuleList = attr.ib(init=False)
attn: partial = attr.ib(init=False)
@@ -40,12 +39,11 @@ class AttentionLayers(nn.Module):
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
- self.has_pos_emb = True if self.rotary_emb is not None else False
self.layer_types = self._get_layer_types() * self.depth
attn = load_partial_fn(
self.attn_fn, dim=self.dim, num_heads=self.num_heads, **self.attn_kwargs
)
- norm = load_partial_fn(self.norm_fn, dim=self.dim)
+ norm = load_partial_fn(self.norm_fn, normalized_shape=self.dim)
ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs)
self.layers = self._build_network(attn, norm, ff)
@@ -103,13 +101,11 @@ class AttentionLayers(nn.Module):
return x
+@attr.s(auto_attribs=True)
class Encoder(AttentionLayers):
- def __init__(self, **kwargs: Any) -> None:
- assert "causal" not in kwargs, "Cannot set causality on encoder"
- super().__init__(causal=False, **kwargs)
+ causal: bool = attr.ib(default=False, init=False)
+@attr.s(auto_attribs=True)
class Decoder(AttentionLayers):
- def __init__(self, **kwargs: Any) -> None:
- assert "causal" not in kwargs, "Cannot set causality on decoder"
- super().__init__(causal=True, **kwargs)
+ causal: bool = attr.ib(default=True, init=False)
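
Replacing the Encoder/Decoder __init__ assertions with an attrs field override pins causal per subclass and drops it from the generated constructor, so callers can no longer pass it at all. A small sketch of this override pattern with illustrative class names:

    import attr

    @attr.s(auto_attribs=True)
    class Layers:
        dim: int = attr.ib()
        causal: bool = attr.ib(default=False)

    @attr.s(auto_attribs=True)
    class CausalLayers(Layers):
        # Fixed value; `causal` is no longer an __init__ argument.
        causal: bool = attr.ib(default=True, init=False)

    print(CausalLayers(dim=8).causal)  # True
    # CausalLayers(dim=8, causal=False) would raise TypeError.
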
diff --git a/text_recognizer/networks/transformer/norm.py b/text_recognizer/networks/transformer/norm.py
index 8bc3221..4930adf 100644
--- a/text_recognizer/networks/transformer/norm.py
+++ b/text_recognizer/networks/transformer/norm.py
@@ -12,9 +12,9 @@ from torch import Tensor
class ScaleNorm(nn.Module):
- def __init__(self, dim: int, eps: float = 1.0e-5) -> None:
+ def __init__(self, normalized_shape: int, eps: float = 1.0e-5) -> None:
super().__init__()
- self.scale = dim ** -0.5
+ self.scale = normalized_shape ** -0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(1))
@@ -24,9 +24,9 @@ class ScaleNorm(nn.Module):
class PreNorm(nn.Module):
- def __init__(self, dim: int, fn: Type[nn.Module]) -> None:
+ def __init__(self, normalized_shape: int, fn: Type[nn.Module]) -> None:
super().__init__()
- self.norm = nn.LayerNorm(dim)
+ self.norm = nn.LayerNorm(normalized_shape)
self.fn = fn
def forward(self, x: Tensor, **kwargs: Dict) -> Tensor:
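
The dim -> normalized_shape rename matches torch.nn.LayerNorm's parameter name, which is what lets AttentionLayers build any of these norms through the same keyword-only partial (norm = load_partial_fn(self.norm_fn, normalized_shape=self.dim)). A small sketch of that interchangeability using nn.LayerNorm, which already accepts the keyword:

    from functools import partial

    import torch
    from torch import nn

    # ScaleNorm and PreNorm now take the same keyword as nn.LayerNorm,
    # so a single partial works regardless of which norm class is configured.
    norm_fn = partial(nn.LayerNorm, normalized_shape=8)
    print(norm_fn()(torch.zeros(2, 8)).shape)  # torch.Size([2, 8])
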
diff --git a/text_recognizer/networks/util.py b/text_recognizer/networks/util.py
index e822c57..c94e8dc 100644
--- a/text_recognizer/networks/util.py
+++ b/text_recognizer/networks/util.py
@@ -24,6 +24,6 @@ def activation_function(activation: str) -> Type[nn.Module]:
def load_partial_fn(fn: str, **kwargs: Any) -> partial:
- """Loads partial function."""
+ """Loads partial function/class."""
module = import_module(".".join(fn.split(".")[:-1]))
- return partial(getattr(module, fn.split(".")[0]), **kwargs)
+ return partial(getattr(module, fn.split(".")[-1]), **kwargs)
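
In load_partial_fn, the attribute to load is the last dotted segment of fn; the previous fn.split(".")[0] looked up the top-level package name on the leaf module, which fails for fully qualified paths. A self-contained sketch of the corrected behaviour, using a stdlib path rather than one of the project's classes:

    from functools import partial
    from importlib import import_module
    from typing import Any

    def load_partial_fn(fn: str, **kwargs: Any) -> partial:
        """Load a partial of the function/class named by a dotted path."""
        module = import_module(".".join(fn.split(".")[:-1]))
        return partial(getattr(module, fn.split(".")[-1]), **kwargs)

    # "collections.OrderedDict" -> module "collections", attribute "OrderedDict"
    make = load_partial_fn("collections.OrderedDict", a=1)
    print(make())  # OrderedDict([('a', 1)])
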
diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py
index 1f08e5e..5aa929b 100644
--- a/text_recognizer/networks/vqvae/vqvae.py
+++ b/text_recognizer/networks/vqvae/vqvae.py
@@ -1,5 +1,4 @@
"""The VQ-VAE."""
-
from typing import Any, Dict, List, Optional, Tuple
from torch import nn