diff options
Diffstat (limited to 'src/text_recognizer/networks')
| -rw-r--r-- | src/text_recognizer/networks/__init__.py | 2 | ||||
| -rw-r--r-- | src/text_recognizer/networks/line_lstm_ctc.py | 63 | ||||
| -rw-r--r-- | src/text_recognizer/networks/losses.py | 31 | ||||
| -rw-r--r-- | src/text_recognizer/networks/residual_network.py | 3 | ||||
| -rw-r--r-- | src/text_recognizer/networks/transformer.py | 4 | 
5 files changed, 88 insertions, 15 deletions
| diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py index d20c86a..a39975f 100644 --- a/src/text_recognizer/networks/__init__.py +++ b/src/text_recognizer/networks/__init__.py @@ -2,12 +2,14 @@  from .ctc import greedy_decoder  from .lenet import LeNet  from .line_lstm_ctc import LineRecurrentNetwork +from .losses import EmbeddingLoss  from .misc import sliding_window  from .mlp import MLP  from .residual_network import ResidualNetwork, ResidualNetworkEncoder  from .wide_resnet import WideResidualNetwork  __all__ = [ +    "EmbeddingLoss",      "greedy_decoder",      "MLP",      "LeNet", diff --git a/src/text_recognizer/networks/line_lstm_ctc.py b/src/text_recognizer/networks/line_lstm_ctc.py index 5c57479..9009f94 100644 --- a/src/text_recognizer/networks/line_lstm_ctc.py +++ b/src/text_recognizer/networks/line_lstm_ctc.py @@ -1,9 +1,11 @@  """LSTM with CTC for handwritten text recognition within a line."""  import importlib +from pathlib import Path  from typing import Callable, Dict, List, Optional, Tuple, Type, Union  from einops import rearrange, reduce  from einops.layers.torch import Rearrange, Reduce +from loguru import logger  import torch  from torch import nn  from torch import Tensor @@ -14,40 +16,72 @@ class LineRecurrentNetwork(nn.Module):      def __init__(          self, -        encoder: str, -        encoder_args: Dict = None, +        backbone: str, +        backbone_args: Dict = None,          flatten: bool = True,          input_size: int = 128,          hidden_size: int = 128, +        bidirectional: bool = False,          num_layers: int = 1,          num_classes: int = 80,          patch_size: Tuple[int, int] = (28, 28),          stride: Tuple[int, int] = (1, 14),      ) -> None:          super().__init__() -        self.encoder_args = encoder_args or {} +        self.backbone_args = backbone_args or {}          self.patch_size = patch_size          self.stride = stride          self.sliding_window = self._configure_sliding_window()          self.input_size = input_size          self.hidden_size = hidden_size -        self.encoder = self._configure_encoder(encoder) +        self.backbone = self._configure_backbone(backbone) +        self.bidirectional = bidirectional          self.flatten = flatten -        self.fc = nn.Linear(in_features=self.input_size, out_features=self.hidden_size) + +        if self.flatten: +            self.fc = nn.Linear( +                in_features=self.input_size, out_features=self.hidden_size +            ) +          self.rnn = nn.LSTM(              input_size=self.hidden_size,              hidden_size=self.hidden_size, +            bidirectional=bidirectional,              num_layers=num_layers,          ) + +        decoder_size = self.hidden_size * 2 if self.bidirectional else self.hidden_size +          self.decoder = nn.Sequential( -            nn.Linear(in_features=self.hidden_size, out_features=num_classes), +            nn.Linear(in_features=decoder_size, out_features=num_classes),              nn.LogSoftmax(dim=2),          ) -    def _configure_encoder(self, encoder: str) -> Type[nn.Module]: +    def _configure_backbone(self, backbone: str) -> Type[nn.Module]:          network_module = importlib.import_module("text_recognizer.networks") -        encoder_ = getattr(network_module, encoder) -        return encoder_(**self.encoder_args) +        backbone_ = getattr(network_module, backbone) + +        if "pretrained" in self.backbone_args: +            logger.info("Loading pretrained backbone.") +            checkpoint_file = Path(__file__).resolve().parents[ +                2 +            ] / self.backbone_args.pop("pretrained") + +            # Loading state directory. +            state_dict = torch.load(checkpoint_file) +            network_args = state_dict["network_args"] +            weights = state_dict["model_state"] + +            # Initializes the network with trained weights. +            backbone = backbone_(**network_args) +            backbone.load_state_dict(weights) +            if "freeze" in self.backbone_args and self.backbone_args["freeze"] is True: +                for params in backbone.parameters(): +                    params.requires_grad = False + +            return backbone +        else: +            return backbone_(**self.backbone_args)      def _configure_sliding_window(self) -> nn.Sequential:          return nn.Sequential( @@ -69,13 +103,14 @@ class LineRecurrentNetwork(nn.Module):          # Rearrange from a sequence of patches for feedforward network.          b, t = x.shape[:2]          x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t) -        x = self.encoder(x) +        x = self.backbone(x)          # Avgerage pooling. -        x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t) if self.flatten else x - -        # Linear layer between CNN and RNN -        x = self.fc(x) +        x = ( +            self.fc(reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)) +            if self.flatten +            else rearrange(x, "(b t) h -> t b h", b=b, t=t) +        )          # Sequence predictions.          x, _ = self.rnn(x) diff --git a/src/text_recognizer/networks/losses.py b/src/text_recognizer/networks/losses.py new file mode 100644 index 0000000..73e0641 --- /dev/null +++ b/src/text_recognizer/networks/losses.py @@ -0,0 +1,31 @@ +"""Implementations of custom loss functions.""" +from pytorch_metric_learning import distances, losses, miners, reducers +from torch import nn +from torch import Tensor + + +class EmbeddingLoss: +    """Metric loss for training encoders to produce information-rich latent embeddings.""" + +    def __init__(self, margin: float = 0.2, type_of_triplets: str = "semihard") -> None: +        self.distance = distances.CosineSimilarity() +        self.reducer = reducers.ThresholdReducer(low=0) +        self.loss_fn = losses.TripletMarginLoss( +            margin=margin, distance=self.distance, reducer=self.reducer +        ) +        self.miner = miners.MultiSimilarityMiner(epsilon=margin, distance=self.distance) + +    def __call__(self, embeddings: Tensor, labels: Tensor) -> Tensor: +        """Computes the metric loss for the embeddings based on their labels. + +        Args: +            embeddings (Tensor): The laten vectors encoded by the network. +            labels (Tensor): Labels of the embeddings. + +        Returns: +            Tensor: The metric loss for the embeddings. + +        """ +        hard_pairs = self.miner(embeddings, labels) +        loss = self.loss_fn(embeddings, labels, hard_pairs) +        return loss diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py index 1b5d6b3..046600d 100644 --- a/src/text_recognizer/networks/residual_network.py +++ b/src/text_recognizer/networks/residual_network.py @@ -278,7 +278,8 @@ class ResidualNetworkEncoder(nn.Module):          if self.stn is not None:              x = self.stn(x)          x = self.gate(x) -        return self.blocks(x) +        x = self.blocks(x) +        return x  class ResidualNetworkDecoder(nn.Module): diff --git a/src/text_recognizer/networks/transformer.py b/src/text_recognizer/networks/transformer.py index 868d739..c091ba0 100644 --- a/src/text_recognizer/networks/transformer.py +++ b/src/text_recognizer/networks/transformer.py @@ -1 +1,5 @@  """TBC.""" +from typing import Dict + +import torch +from torch import Tensor |