Diffstat (limited to 'src/text_recognizer')
-rw-r--r--  src/text_recognizer/character_predictor.py                                                   1
-rw-r--r--  src/text_recognizer/models/base.py                                                          56
-rw-r--r--  src/text_recognizer/models/character_model.py                                                6
-rw-r--r--  src/text_recognizer/models/line_ctc_model.py                                                 8
-rw-r--r--  src/text_recognizer/networks/__init__.py                                                     2
-rw-r--r--  src/text_recognizer/networks/line_lstm_ctc.py                                               63
-rw-r--r--  src/text_recognizer/networks/losses.py                                                      31
-rw-r--r--  src/text_recognizer/networks/residual_network.py                                             3
-rw-r--r--  src/text_recognizer/networks/transformer.py                                                  4
-rw-r--r--  src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetworkEncoder_weights.pt  bin 0 -> 26090486 bytes
-rw-r--r--  src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt         bin 5003730 -> 32765213 bytes
-rw-r--r--  src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt    bin 45257014 -> 20694308 bytes
12 files changed, 128 insertions, 46 deletions
diff --git a/src/text_recognizer/character_predictor.py b/src/text_recognizer/character_predictor.py
index b733a53..df37e68 100644
--- a/src/text_recognizer/character_predictor.py
+++ b/src/text_recognizer/character_predictor.py
@@ -15,6 +15,7 @@ class CharacterPredictor:
"""Intializes the CharacterModel and load the pretrained weights."""
self.model = CharacterModel(network_fn=network_fn)
self.model.eval()
+ self.model.use_swa_model()
def predict(self, image_or_filename: Union[np.ndarray, str]) -> Tuple[str, float]:
"""Predict on a single images contianing a handwritten character."""
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index d23fe56..caf8065 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -77,9 +77,9 @@ class Model(ABC):
# Stochastic Weight Averaging placeholders.
self.swa_args = swa_args
- self._swa_start = None
self._swa_scheduler = None
self._swa_network = None
+ self._use_swa_model = False
# Experiment directory.
self.model_dir = None
@@ -220,15 +220,24 @@ class Model(ABC):
if self._optimizer and self._lr_scheduler is not None:
if "OneCycleLR" in str(self._lr_scheduler):
self.lr_scheduler_args["steps_per_epoch"] = len(self.train_dataloader())
- self._lr_scheduler = self._lr_scheduler(
- self._optimizer, **self.lr_scheduler_args
- )
- else:
- self._lr_scheduler = None
+
+ # Assume lr scheduler should update at each epoch if not specified.
+ if "interval" not in self.lr_scheduler_args:
+ interval = "epoch"
+ else:
+ interval = self.lr_scheduler_args.pop("interval")
+ self._lr_scheduler = {
+ "lr_scheduler": self._lr_scheduler(
+ self._optimizer, **self.lr_scheduler_args
+ ),
+ "interval": interval,
+ }
if self.swa_args is not None:
- self._swa_start = self.swa_args["start"]
- self._swa_scheduler = SWALR(self._optimizer, swa_lr=self.swa_args["lr"])
+ self._swa_scheduler = {
+ "swa_scheduler": SWALR(self._optimizer, swa_lr=self.swa_args["lr"]),
+ "swa_start": self.swa_args["start"],
+ }
self._swa_network = AveragedModel(self._network).to(self.device)
@property
@@ -280,21 +289,16 @@ class Model(ABC):
return self._optimizer
@property
- def lr_scheduler(self) -> Optional[Callable]:
- """Learning rate scheduler."""
+ def lr_scheduler(self) -> Optional[Dict]:
+ """Returns a directory with the learning rate scheduler."""
return self._lr_scheduler
@property
- def swa_scheduler(self) -> Optional[Callable]:
- """Returns the stochastic weight averaging scheduler."""
+ def swa_scheduler(self) -> Optional[Dict]:
+ """Returns a directory with the stochastic weight averaging scheduler."""
return self._swa_scheduler
@property
- def swa_start(self) -> Optional[Callable]:
- """Returns the start epoch of stochastic weight averaging."""
- return self._swa_start
-
- @property
def swa_network(self) -> Optional[Callable]:
"""Returns the stochastic weight averaging network."""
return self._swa_network
@@ -311,20 +315,32 @@ class Model(ABC):
WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True)
return str(WEIGHT_DIRNAME / f"{self._name}_weights.pt")
+ def use_swa_model(self) -> None:
+ """Set to use predictions from SWA model."""
+ if self.swa_network is not None:
+ self._use_swa_model = True
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Feedforward pass with the network."""
+ if self._use_swa_model:
+ return self.swa_network(x)
+ else:
+ return self.network(x)
+
def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor:
"""Compute the loss."""
return self.criterion(output, targets)
def summary(
- self, input_shape: Optional[Tuple[int, int, int]] = None, depth: int = 5
+ self, input_shape: Optional[Tuple[int, int, int]] = None, depth: int = 3
) -> None:
"""Prints a summary of the network architecture."""
if input_shape is not None:
- summary(self._network, input_shape, depth=depth, device=self.device)
+ summary(self.network, input_shape, depth=depth, device=self.device)
elif self._input_shape is not None:
input_shape = (1,) + tuple(self._input_shape)
- summary(self._network, input_shape, depth=depth, device=self.device)
+ summary(self.network, input_shape, depth=depth, device=self.device)
else:
logger.warning("Could not print summary as input shape is not set.")
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
index 64ba693..50e94a2 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/src/text_recognizer/models/character_model.py
@@ -75,11 +75,7 @@ class CharacterModel(Model):
# Put the image tensor on the device the model weights are on.
image = image.to(self.device)
- logits = (
- self.swa_network(image)
- if self.swa_network is not None
- else self.network(image)
- )
+ logits = self.forward(image)
prediction = self.softmax(logits.squeeze(0))
diff --git a/src/text_recognizer/models/line_ctc_model.py b/src/text_recognizer/models/line_ctc_model.py
index af41f18..16eaed3 100644
--- a/src/text_recognizer/models/line_ctc_model.py
+++ b/src/text_recognizer/models/line_ctc_model.py
@@ -98,16 +98,12 @@ class LineCTCModel(Model):
# Put the image tensor on the device the model weights are on.
image = image.to(self.device)
- log_probs = (
- self.swa_network(image)
- if self.swa_network is not None
- else self.network(image)
- )
+ log_probs = self.forward(image)
raw_pred, _ = greedy_decoder(
predictions=log_probs,
character_mapper=self.mapper,
- blank_label=80,
+ blank_label=79,
collapse_repeated=True,
)
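The blank index change is consistent with a zero-indexed label set: with num_classes=80, valid symbol indices run 0-78 and the CTC blank occupies the last slot. A one-line sanity check, assuming that convention:

num_classes = 80
blank_label = num_classes - 1  # 79, matching the greedy_decoder call above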
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
index d20c86a..a39975f 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/src/text_recognizer/networks/__init__.py
@@ -2,12 +2,14 @@
from .ctc import greedy_decoder
from .lenet import LeNet
from .line_lstm_ctc import LineRecurrentNetwork
+from .losses import EmbeddingLoss
from .misc import sliding_window
from .mlp import MLP
from .residual_network import ResidualNetwork, ResidualNetworkEncoder
from .wide_resnet import WideResidualNetwork
__all__ = [
+ "EmbeddingLoss",
"greedy_decoder",
"MLP",
"LeNet",
diff --git a/src/text_recognizer/networks/line_lstm_ctc.py b/src/text_recognizer/networks/line_lstm_ctc.py
index 5c57479..9009f94 100644
--- a/src/text_recognizer/networks/line_lstm_ctc.py
+++ b/src/text_recognizer/networks/line_lstm_ctc.py
@@ -1,9 +1,11 @@
"""LSTM with CTC for handwritten text recognition within a line."""
import importlib
+from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
from einops import rearrange, reduce
from einops.layers.torch import Rearrange, Reduce
+from loguru import logger
import torch
from torch import nn
from torch import Tensor
@@ -14,40 +16,72 @@ class LineRecurrentNetwork(nn.Module):
def __init__(
self,
- encoder: str,
- encoder_args: Dict = None,
+ backbone: str,
+ backbone_args: Dict = None,
flatten: bool = True,
input_size: int = 128,
hidden_size: int = 128,
+ bidirectional: bool = False,
num_layers: int = 1,
num_classes: int = 80,
patch_size: Tuple[int, int] = (28, 28),
stride: Tuple[int, int] = (1, 14),
) -> None:
super().__init__()
- self.encoder_args = encoder_args or {}
+ self.backbone_args = backbone_args or {}
self.patch_size = patch_size
self.stride = stride
self.sliding_window = self._configure_sliding_window()
self.input_size = input_size
self.hidden_size = hidden_size
- self.encoder = self._configure_encoder(encoder)
+ self.backbone = self._configure_backbone(backbone)
+ self.bidirectional = bidirectional
self.flatten = flatten
- self.fc = nn.Linear(in_features=self.input_size, out_features=self.hidden_size)
+
+ if self.flatten:
+ self.fc = nn.Linear(
+ in_features=self.input_size, out_features=self.hidden_size
+ )
+
self.rnn = nn.LSTM(
input_size=self.hidden_size,
hidden_size=self.hidden_size,
+ bidirectional=bidirectional,
num_layers=num_layers,
)
+
+ decoder_size = self.hidden_size * 2 if self.bidirectional else self.hidden_size
+
self.decoder = nn.Sequential(
- nn.Linear(in_features=self.hidden_size, out_features=num_classes),
+ nn.Linear(in_features=decoder_size, out_features=num_classes),
nn.LogSoftmax(dim=2),
)
- def _configure_encoder(self, encoder: str) -> Type[nn.Module]:
+ def _configure_backbone(self, backbone: str) -> Type[nn.Module]:
network_module = importlib.import_module("text_recognizer.networks")
- encoder_ = getattr(network_module, encoder)
- return encoder_(**self.encoder_args)
+ backbone_ = getattr(network_module, backbone)
+
+ if "pretrained" in self.backbone_args:
+ logger.info("Loading pretrained backbone.")
+ checkpoint_file = Path(__file__).resolve().parents[
+ 2
+ ] / self.backbone_args.pop("pretrained")
+
+ # Load the checkpoint state dict.
+ state_dict = torch.load(checkpoint_file)
+ network_args = state_dict["network_args"]
+ weights = state_dict["model_state"]
+
+ # Initializes the network with trained weights.
+ backbone = backbone_(**network_args)
+ backbone.load_state_dict(weights)
+ if "freeze" in self.backbone_args and self.backbone_args["freeze"] is True:
+ for params in backbone.parameters():
+ params.requires_grad = False
+
+ return backbone
+ else:
+ return backbone_(**self.backbone_args)
def _configure_sliding_window(self) -> nn.Sequential:
return nn.Sequential(
@@ -69,13 +103,14 @@ class LineRecurrentNetwork(nn.Module):
# Rearrange from a sequence of patches for feedforward network.
b, t = x.shape[:2]
x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
- x = self.encoder(x)
+ x = self.backbone(x)
# Average pooling.
- x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t) if self.flatten else x
-
- # Linear layer between CNN and RNN
- x = self.fc(x)
+ x = (
+ self.fc(reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t))
+ if self.flatten
+ else rearrange(x, "(b t) h -> t b h", b=b, t=t)
+ )
# Sequence predictions.
x, _ = self.rnn(x)
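A hypothetical instantiation of the renamed backbone interface; the checkpoint path and argument values are illustrative only, and the pretrained path is resolved relative to the src/ directory by _configure_backbone:

from text_recognizer.networks import LineRecurrentNetwork

network = LineRecurrentNetwork(
    backbone="ResidualNetworkEncoder",
    backbone_args={
        "pretrained": "text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetworkEncoder_weights.pt",
        "freeze": True,  # keep the pretrained encoder weights fixed
    },
    input_size=128,   # must match the backbone's output channels when flatten=True
    hidden_size=128,
    bidirectional=True,
    num_classes=80,
)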
diff --git a/src/text_recognizer/networks/losses.py b/src/text_recognizer/networks/losses.py
new file mode 100644
index 0000000..73e0641
--- /dev/null
+++ b/src/text_recognizer/networks/losses.py
@@ -0,0 +1,31 @@
+"""Implementations of custom loss functions."""
+from pytorch_metric_learning import distances, losses, miners, reducers
+from torch import nn
+from torch import Tensor
+
+
+class EmbeddingLoss:
+ """Metric loss for training encoders to produce information-rich latent embeddings."""
+
+ def __init__(self, margin: float = 0.2, type_of_triplets: str = "semihard") -> None:
+ self.distance = distances.CosineSimilarity()
+ self.reducer = reducers.ThresholdReducer(low=0)
+ self.loss_fn = losses.TripletMarginLoss(
+ margin=margin, distance=self.distance, reducer=self.reducer
+ )
+ self.miner = miners.MultiSimilarityMiner(epsilon=margin, distance=self.distance)
+
+ def __call__(self, embeddings: Tensor, labels: Tensor) -> Tensor:
+ """Computes the metric loss for the embeddings based on their labels.
+
+ Args:
+ embeddings (Tensor): The latent vectors encoded by the network.
+ labels (Tensor): Labels of the embeddings.
+
+ Returns:
+ Tensor: The metric loss for the embeddings.
+
+ """
+ hard_pairs = self.miner(embeddings, labels)
+ loss = self.loss_fn(embeddings, labels, hard_pairs)
+ return loss
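A short usage sketch for the new metric loss; the tensor shapes and label count are illustrative:

import torch

from text_recognizer.networks import EmbeddingLoss

# Toy batch: eight 128-dimensional embeddings drawn from four classes.
embeddings = torch.randn(8, 128, requires_grad=True)
labels = torch.randint(0, 4, (8,))

embedding_loss = EmbeddingLoss(margin=0.2)
loss = embedding_loss(embeddings, labels)
loss.backward()  # gradients flow back to the encoder in a real training step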
diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py
index 1b5d6b3..046600d 100644
--- a/src/text_recognizer/networks/residual_network.py
+++ b/src/text_recognizer/networks/residual_network.py
@@ -278,7 +278,8 @@ class ResidualNetworkEncoder(nn.Module):
if self.stn is not None:
x = self.stn(x)
x = self.gate(x)
- return self.blocks(x)
+ x = self.blocks(x)
+ return x
class ResidualNetworkDecoder(nn.Module):
diff --git a/src/text_recognizer/networks/transformer.py b/src/text_recognizer/networks/transformer.py
index 868d739..c091ba0 100644
--- a/src/text_recognizer/networks/transformer.py
+++ b/src/text_recognizer/networks/transformer.py
@@ -1 +1,5 @@
"""TBC."""
+from typing import Dict
+
+import torch
+from torch import Tensor
diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetworkEncoder_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetworkEncoder_weights.pt
new file mode 100644
index 0000000..9f9deee
--- /dev/null
+++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetworkEncoder_weights.pt
Binary files differ
diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt
index a25bcd1..0dc7eb5 100644
--- a/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt
+++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt
Binary files differ
diff --git a/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt b/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt
index 9bd8ca2..93d34d7 100644
--- a/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt
+++ b/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt
Binary files differ