diff options
Diffstat (limited to 'src/text_recognizer/networks/line_lstm_ctc.py')
-rw-r--r-- | src/text_recognizer/networks/line_lstm_ctc.py | 120 |
1 files changed, 0 insertions, 120 deletions
diff --git a/src/text_recognizer/networks/line_lstm_ctc.py b/src/text_recognizer/networks/line_lstm_ctc.py deleted file mode 100644 index 9009f94..0000000 --- a/src/text_recognizer/networks/line_lstm_ctc.py +++ /dev/null @@ -1,120 +0,0 @@ -"""LSTM with CTC for handwritten text recognition within a line.""" -import importlib -from pathlib import Path -from typing import Callable, Dict, List, Optional, Tuple, Type, Union - -from einops import rearrange, reduce -from einops.layers.torch import Rearrange, Reduce -from loguru import logger -import torch -from torch import nn -from torch import Tensor - - -class LineRecurrentNetwork(nn.Module): - """Network that takes a image of a text line and predicts tokens that are in the image.""" - - def __init__( - self, - backbone: str, - backbone_args: Dict = None, - flatten: bool = True, - input_size: int = 128, - hidden_size: int = 128, - bidirectional: bool = False, - num_layers: int = 1, - num_classes: int = 80, - patch_size: Tuple[int, int] = (28, 28), - stride: Tuple[int, int] = (1, 14), - ) -> None: - super().__init__() - self.backbone_args = backbone_args or {} - self.patch_size = patch_size - self.stride = stride - self.sliding_window = self._configure_sliding_window() - self.input_size = input_size - self.hidden_size = hidden_size - self.backbone = self._configure_backbone(backbone) - self.bidirectional = bidirectional - self.flatten = flatten - - if self.flatten: - self.fc = nn.Linear( - in_features=self.input_size, out_features=self.hidden_size - ) - - self.rnn = nn.LSTM( - input_size=self.hidden_size, - hidden_size=self.hidden_size, - bidirectional=bidirectional, - num_layers=num_layers, - ) - - decoder_size = self.hidden_size * 2 if self.bidirectional else self.hidden_size - - self.decoder = nn.Sequential( - nn.Linear(in_features=decoder_size, out_features=num_classes), - nn.LogSoftmax(dim=2), - ) - - def _configure_backbone(self, backbone: str) -> Type[nn.Module]: - network_module = importlib.import_module("text_recognizer.networks") - backbone_ = getattr(network_module, backbone) - - if "pretrained" in self.backbone_args: - logger.info("Loading pretrained backbone.") - checkpoint_file = Path(__file__).resolve().parents[ - 2 - ] / self.backbone_args.pop("pretrained") - - # Loading state directory. - state_dict = torch.load(checkpoint_file) - network_args = state_dict["network_args"] - weights = state_dict["model_state"] - - # Initializes the network with trained weights. - backbone = backbone_(**network_args) - backbone.load_state_dict(weights) - if "freeze" in self.backbone_args and self.backbone_args["freeze"] is True: - for params in backbone.parameters(): - params.requires_grad = False - - return backbone - else: - return backbone_(**self.backbone_args) - - def _configure_sliding_window(self) -> nn.Sequential: - return nn.Sequential( - nn.Unfold(kernel_size=self.patch_size, stride=self.stride), - Rearrange( - "b (c h w) t -> b t c h w", - h=self.patch_size[0], - w=self.patch_size[1], - c=1, - ), - ) - - def forward(self, x: Tensor) -> Tensor: - """Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM.""" - if len(x.shape) == 3: - x = x.unsqueeze(0) - x = self.sliding_window(x) - - # Rearrange from a sequence of patches for feedforward network. - b, t = x.shape[:2] - x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t) - x = self.backbone(x) - - # Avgerage pooling. - x = ( - self.fc(reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)) - if self.flatten - else rearrange(x, "(b t) h -> t b h", b=b, t=t) - ) - - # Sequence predictions. - x, _ = self.rnn(x) - - # Sequence to classifcation layer. - x = self.decoder(x) - return x |