summaryrefslogtreecommitdiff
path: root/text_recognizer
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer')
-rw-r--r--text_recognizer/data/mapping.py8
-rw-r--r--text_recognizer/models/base.py1
-rw-r--r--text_recognizer/models/transformer.py5
-rw-r--r--text_recognizer/networks/image_transformer.py23
4 files changed, 17 insertions, 20 deletions
diff --git a/text_recognizer/data/mapping.py b/text_recognizer/data/mapping.py
new file mode 100644
index 0000000..f0edf3f
--- /dev/null
+++ b/text_recognizer/data/mapping.py
@@ -0,0 +1,8 @@
+"""Mapping to and from word pieces."""
+from pathlib import Path
+
+
+class WordPieces:
+
+ def __init__(self) -> None:
+ pass
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 0928e6c..c6d5d73 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -60,6 +60,7 @@ class LitBaseModel(pl.LightningModule):
scheduler["scheduler"] = getattr(
torch.optim.lr_scheduler, self._lr_scheduler.type
)(optimizer, **args)
+
return scheduler
def configure_optimizers(self) -> Tuple[List[type], List[Dict[str, Any]]]:
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index b23685b..7dc1352 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,5 +1,5 @@
"""PyTorch Lightning model for base Transformers."""
-from typing import Dict, List, Optional, Union, Tuple
+from typing import Dict, List, Optional, Union, Tuple, Type
from omegaconf import DictConfig, OmegaConf
import pytorch_lightning as pl
@@ -19,7 +19,7 @@ class LitTransformerModel(LitBaseModel):
def __init__(
self,
- network: Type[nn, Module],
+ network: Type[nn.Module],
optimizer: Union[DictConfig, Dict],
lr_scheduler: Union[DictConfig, Dict],
criterion: Union[DictConfig, Dict],
@@ -27,7 +27,6 @@ class LitTransformerModel(LitBaseModel):
mapping: Optional[List[str]] = None,
) -> None:
super().__init__(network, optimizer, lr_scheduler, criterion, monitor)
-
self.mapping, ignore_tokens = self.configure_mapping(mapping)
self.val_cer = CharacterErrorRate(ignore_tokens)
self.test_cer = CharacterErrorRate(ignore_tokens)
diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py
index 9ed67a4..daededa 100644
--- a/text_recognizer/networks/image_transformer.py
+++ b/text_recognizer/networks/image_transformer.py
@@ -10,16 +10,15 @@ TODO: Local attention for lower layer in attention.
"""
import importlib
import math
-from typing import Dict, List, Union, Sequence, Tuple, Type
+from typing import Dict, Optional, Union, Sequence, Type
from einops import rearrange
from omegaconf import DictConfig, OmegaConf
import torch
from torch import nn
from torch import Tensor
-import torchvision
-from text_recognizer.data.emnist import emnist_mapping
+from text_recognizer.data.emnist import NUM_SPECIAL_TOKENS
from text_recognizer.networks.transformer import (
Decoder,
DecoderLayer,
@@ -28,6 +27,8 @@ from text_recognizer.networks.transformer import (
target_padding_mask,
)
+NUM_WORD_PIECES = 1000
+
class ImageTransformer(nn.Module):
def __init__(
@@ -35,7 +36,7 @@ class ImageTransformer(nn.Module):
input_shape: Sequence[int],
output_shape: Sequence[int],
encoder: Union[DictConfig, Dict],
- mapping: str,
+ vocab_size: Optional[int] = None,
num_decoder_layers: int = 4,
hidden_dim: int = 256,
num_heads: int = 4,
@@ -43,14 +44,9 @@ class ImageTransformer(nn.Module):
dropout_rate: float = 0.1,
transformer_activation: str = "glu",
) -> None:
- # Configure mapping
- mapping, inverse_mapping = self._configure_mapping(mapping)
- self.vocab_size = len(mapping)
+ self.vocab_size = NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size
self.hidden_dim = hidden_dim
self.max_output_length = output_shape[0]
- self.start_index = inverse_mapping["<s>"]
- self.end_index = inverse_mapping["<e>"]
- self.pad_index = inverse_mapping["<p>"]
# Image backbone
self.encoder = self._configure_encoder(encoder)
@@ -107,13 +103,6 @@ class ImageTransformer(nn.Module):
encoder_class = getattr(network_module, encoder.type)
return encoder_class(**encoder.args)
- def _configure_mapping(self, mapping: str) -> Tuple[List[str], Dict[str, int]]:
- """Configures mapping."""
- # TODO: Fix me!!!
- if mapping == "emnist":
- mapping, inverse_mapping, _ = emnist_mapping()
- return mapping, inverse_mapping
-
def encode(self, image: Tensor) -> Tensor:
"""Extracts image features with backbone.