diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-20 00:14:27 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-20 00:14:27 +0200 |
commit | e181195a699d7fa237f256d90ab4dedffc03d405 (patch) | |
tree | 6d8d50731a7267c56f7bf3ed5ecec3990c0e55a5 /src/text_recognizer | |
parent | 3b06ef615a8db67a03927576e0c12fbfb2501f5f (diff) |
Minor bug fixes etc.
Diffstat (limited to 'src/text_recognizer')
-rw-r--r-- | src/text_recognizer/character_predictor.py | 1 | ||||
-rw-r--r-- | src/text_recognizer/models/base.py | 56 | ||||
-rw-r--r-- | src/text_recognizer/models/character_model.py | 6 | ||||
-rw-r--r-- | src/text_recognizer/models/line_ctc_model.py | 8 | ||||
-rw-r--r-- | src/text_recognizer/networks/__init__.py | 2 | ||||
-rw-r--r-- | src/text_recognizer/networks/line_lstm_ctc.py | 63 | ||||
-rw-r--r-- | src/text_recognizer/networks/losses.py | 31 | ||||
-rw-r--r-- | src/text_recognizer/networks/residual_network.py | 3 | ||||
-rw-r--r-- | src/text_recognizer/networks/transformer.py | 4 | ||||
-rw-r--r-- | src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetworkEncoder_weights.pt | bin | 0 -> 26090486 bytes | |||
-rw-r--r-- | src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt | bin | 5003730 -> 32765213 bytes | |||
-rw-r--r-- | src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt | bin | 45257014 -> 20694308 bytes |
12 files changed, 128 insertions, 46 deletions
diff --git a/src/text_recognizer/character_predictor.py b/src/text_recognizer/character_predictor.py index b733a53..df37e68 100644 --- a/src/text_recognizer/character_predictor.py +++ b/src/text_recognizer/character_predictor.py @@ -15,6 +15,7 @@ class CharacterPredictor: """Intializes the CharacterModel and load the pretrained weights.""" self.model = CharacterModel(network_fn=network_fn) self.model.eval() + self.model.use_swa_model() def predict(self, image_or_filename: Union[np.ndarray, str]) -> Tuple[str, float]: """Predict on a single images contianing a handwritten character.""" diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index d23fe56..caf8065 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -77,9 +77,9 @@ class Model(ABC): # Stochastic Weight Averaging placeholders. self.swa_args = swa_args - self._swa_start = None self._swa_scheduler = None self._swa_network = None + self._use_swa_model = False # Experiment directory. self.model_dir = None @@ -220,15 +220,24 @@ class Model(ABC): if self._optimizer and self._lr_scheduler is not None: if "OneCycleLR" in str(self._lr_scheduler): self.lr_scheduler_args["steps_per_epoch"] = len(self.train_dataloader()) - self._lr_scheduler = self._lr_scheduler( - self._optimizer, **self.lr_scheduler_args - ) - else: - self._lr_scheduler = None + + # Assume lr scheduler should update at each epoch if not specified. 
+ if "interval" not in self.lr_scheduler_args: + interval = "epoch" + else: + interval = self.lr_scheduler_args.pop("interval") + self._lr_scheduler = { + "lr_scheduler": self._lr_scheduler( + self._optimizer, **self.lr_scheduler_args + ), + "interval": interval, + } if self.swa_args is not None: - self._swa_start = self.swa_args["start"] - self._swa_scheduler = SWALR(self._optimizer, swa_lr=self.swa_args["lr"]) + self._swa_scheduler = { + "swa_scheduler": SWALR(self._optimizer, swa_lr=self.swa_args["lr"]), + "swa_start": self.swa_args["start"], + } self._swa_network = AveragedModel(self._network).to(self.device) @property @@ -280,21 +289,16 @@ class Model(ABC): return self._optimizer @property - def lr_scheduler(self) -> Optional[Callable]: - """Learning rate scheduler.""" + def lr_scheduler(self) -> Optional[Dict]: + """Returns a directory with the learning rate scheduler.""" return self._lr_scheduler @property - def swa_scheduler(self) -> Optional[Callable]: - """Returns the stochastic weight averaging scheduler.""" + def swa_scheduler(self) -> Optional[Dict]: + """Returns a directory with the stochastic weight averaging scheduler.""" return self._swa_scheduler @property - def swa_start(self) -> Optional[Callable]: - """Returns the start epoch of stochastic weight averaging.""" - return self._swa_start - - @property def swa_network(self) -> Optional[Callable]: """Returns the stochastic weight averaging network.""" return self._swa_network @@ -311,20 +315,32 @@ class Model(ABC): WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True) return str(WEIGHT_DIRNAME / f"{self._name}_weights.pt") + def use_swa_model(self) -> None: + """Set to use predictions from SWA model.""" + if self.swa_network is not None: + self._use_swa_model = True + + def forward(self, x: Tensor) -> Tensor: + """Feedforward pass with the network.""" + if self._use_swa_model: + return self.swa_network(x) + else: + return self.network(x) + def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor: 
"""Compute the loss.""" return self.criterion(output, targets) def summary( - self, input_shape: Optional[Tuple[int, int, int]] = None, depth: int = 5 + self, input_shape: Optional[Tuple[int, int, int]] = None, depth: int = 3 ) -> None: """Prints a summary of the network architecture.""" if input_shape is not None: - summary(self._network, input_shape, depth=depth, device=self.device) + summary(self.network, input_shape, depth=depth, device=self.device) elif self._input_shape is not None: input_shape = (1,) + tuple(self._input_shape) - summary(self._network, input_shape, depth=depth, device=self.device) + summary(self.network, input_shape, depth=depth, device=self.device) else: logger.warning("Could not print summary as input shape is not set.") diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py index 64ba693..50e94a2 100644 --- a/src/text_recognizer/models/character_model.py +++ b/src/text_recognizer/models/character_model.py @@ -75,11 +75,7 @@ class CharacterModel(Model): # Put the image tensor on the device the model weights are on. image = image.to(self.device) - logits = ( - self.swa_network(image) - if self.swa_network is not None - else self.network(image) - ) + logits = self.forward(image) prediction = self.softmax(logits.squeeze(0)) diff --git a/src/text_recognizer/models/line_ctc_model.py b/src/text_recognizer/models/line_ctc_model.py index af41f18..16eaed3 100644 --- a/src/text_recognizer/models/line_ctc_model.py +++ b/src/text_recognizer/models/line_ctc_model.py @@ -98,16 +98,12 @@ class LineCTCModel(Model): # Put the image tensor on the device the model weights are on. 
image = image.to(self.device) - log_probs = ( - self.swa_network(image) - if self.swa_network is not None - else self.network(image) - ) + log_probs = self.forward(image) raw_pred, _ = greedy_decoder( predictions=log_probs, character_mapper=self.mapper, - blank_label=80, + blank_label=79, collapse_repeated=True, ) diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py index d20c86a..a39975f 100644 --- a/src/text_recognizer/networks/__init__.py +++ b/src/text_recognizer/networks/__init__.py @@ -2,12 +2,14 @@ from .ctc import greedy_decoder from .lenet import LeNet from .line_lstm_ctc import LineRecurrentNetwork +from .losses import EmbeddingLoss from .misc import sliding_window from .mlp import MLP from .residual_network import ResidualNetwork, ResidualNetworkEncoder from .wide_resnet import WideResidualNetwork __all__ = [ + "EmbeddingLoss", "greedy_decoder", "MLP", "LeNet", diff --git a/src/text_recognizer/networks/line_lstm_ctc.py b/src/text_recognizer/networks/line_lstm_ctc.py index 5c57479..9009f94 100644 --- a/src/text_recognizer/networks/line_lstm_ctc.py +++ b/src/text_recognizer/networks/line_lstm_ctc.py @@ -1,9 +1,11 @@ """LSTM with CTC for handwritten text recognition within a line.""" import importlib +from pathlib import Path from typing import Callable, Dict, List, Optional, Tuple, Type, Union from einops import rearrange, reduce from einops.layers.torch import Rearrange, Reduce +from loguru import logger import torch from torch import nn from torch import Tensor @@ -14,40 +16,72 @@ class LineRecurrentNetwork(nn.Module): def __init__( self, - encoder: str, - encoder_args: Dict = None, + backbone: str, + backbone_args: Dict = None, flatten: bool = True, input_size: int = 128, hidden_size: int = 128, + bidirectional: bool = False, num_layers: int = 1, num_classes: int = 80, patch_size: Tuple[int, int] = (28, 28), stride: Tuple[int, int] = (1, 14), ) -> None: super().__init__() - self.encoder_args = encoder_args or 
{} + self.backbone_args = backbone_args or {} self.patch_size = patch_size self.stride = stride self.sliding_window = self._configure_sliding_window() self.input_size = input_size self.hidden_size = hidden_size - self.encoder = self._configure_encoder(encoder) + self.backbone = self._configure_backbone(backbone) + self.bidirectional = bidirectional self.flatten = flatten - self.fc = nn.Linear(in_features=self.input_size, out_features=self.hidden_size) + + if self.flatten: + self.fc = nn.Linear( + in_features=self.input_size, out_features=self.hidden_size + ) + self.rnn = nn.LSTM( input_size=self.hidden_size, hidden_size=self.hidden_size, + bidirectional=bidirectional, num_layers=num_layers, ) + + decoder_size = self.hidden_size * 2 if self.bidirectional else self.hidden_size + self.decoder = nn.Sequential( - nn.Linear(in_features=self.hidden_size, out_features=num_classes), + nn.Linear(in_features=decoder_size, out_features=num_classes), nn.LogSoftmax(dim=2), ) - def _configure_encoder(self, encoder: str) -> Type[nn.Module]: + def _configure_backbone(self, backbone: str) -> Type[nn.Module]: network_module = importlib.import_module("text_recognizer.networks") - encoder_ = getattr(network_module, encoder) - return encoder_(**self.encoder_args) + backbone_ = getattr(network_module, backbone) + + if "pretrained" in self.backbone_args: + logger.info("Loading pretrained backbone.") + checkpoint_file = Path(__file__).resolve().parents[ + 2 + ] / self.backbone_args.pop("pretrained") + + # Loading state directory. + state_dict = torch.load(checkpoint_file) + network_args = state_dict["network_args"] + weights = state_dict["model_state"] + + # Initializes the network with trained weights. 
+ backbone = backbone_(**network_args) + backbone.load_state_dict(weights) + if "freeze" in self.backbone_args and self.backbone_args["freeze"] is True: + for params in backbone.parameters(): + params.requires_grad = False + + return backbone + else: + return backbone_(**self.backbone_args) def _configure_sliding_window(self) -> nn.Sequential: return nn.Sequential( @@ -69,13 +103,14 @@ class LineRecurrentNetwork(nn.Module): # Rearrange from a sequence of patches for feedforward network. b, t = x.shape[:2] x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t) - x = self.encoder(x) + x = self.backbone(x) # Avgerage pooling. - x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t) if self.flatten else x - - # Linear layer between CNN and RNN - x = self.fc(x) + x = ( + self.fc(reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)) + if self.flatten + else rearrange(x, "(b t) h -> t b h", b=b, t=t) + ) # Sequence predictions. x, _ = self.rnn(x) diff --git a/src/text_recognizer/networks/losses.py b/src/text_recognizer/networks/losses.py new file mode 100644 index 0000000..73e0641 --- /dev/null +++ b/src/text_recognizer/networks/losses.py @@ -0,0 +1,31 @@ +"""Implementations of custom loss functions.""" +from pytorch_metric_learning import distances, losses, miners, reducers +from torch import nn +from torch import Tensor + + +class EmbeddingLoss: + """Metric loss for training encoders to produce information-rich latent embeddings.""" + + def __init__(self, margin: float = 0.2, type_of_triplets: str = "semihard") -> None: + self.distance = distances.CosineSimilarity() + self.reducer = reducers.ThresholdReducer(low=0) + self.loss_fn = losses.TripletMarginLoss( + margin=margin, distance=self.distance, reducer=self.reducer + ) + self.miner = miners.MultiSimilarityMiner(epsilon=margin, distance=self.distance) + + def __call__(self, embeddings: Tensor, labels: Tensor) -> Tensor: + """Computes the metric loss for the embeddings based on their labels. 
+ + Args: + embeddings (Tensor): The laten vectors encoded by the network. + labels (Tensor): Labels of the embeddings. + + Returns: + Tensor: The metric loss for the embeddings. + + """ + hard_pairs = self.miner(embeddings, labels) + loss = self.loss_fn(embeddings, labels, hard_pairs) + return loss diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py index 1b5d6b3..046600d 100644 --- a/src/text_recognizer/networks/residual_network.py +++ b/src/text_recognizer/networks/residual_network.py @@ -278,7 +278,8 @@ class ResidualNetworkEncoder(nn.Module): if self.stn is not None: x = self.stn(x) x = self.gate(x) - return self.blocks(x) + x = self.blocks(x) + return x class ResidualNetworkDecoder(nn.Module): diff --git a/src/text_recognizer/networks/transformer.py b/src/text_recognizer/networks/transformer.py index 868d739..c091ba0 100644 --- a/src/text_recognizer/networks/transformer.py +++ b/src/text_recognizer/networks/transformer.py @@ -1 +1,5 @@ """TBC.""" +from typing import Dict + +import torch +from torch import Tensor diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetworkEncoder_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetworkEncoder_weights.pt Binary files differnew file mode 100644 index 0000000..9f9deee --- /dev/null +++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetworkEncoder_weights.pt diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt Binary files differindex a25bcd1..0dc7eb5 100644 --- a/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt +++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt diff --git a/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt 
b/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt Binary files differ index 9bd8ca2..93d34d7 100644 --- a/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt +++ b/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt |