summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/line_lstm_ctc.py
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-09-20 00:14:27 +0200
committeraktersnurra <gustaf.rydholm@gmail.com>2020-09-20 00:14:27 +0200
commite181195a699d7fa237f256d90ab4dedffc03d405 (patch)
tree6d8d50731a7267c56f7bf3ed5ecec3990c0e55a5 /src/text_recognizer/networks/line_lstm_ctc.py
parent3b06ef615a8db67a03927576e0c12fbfb2501f5f (diff)
Minor bug fixes etc.
Diffstat (limited to 'src/text_recognizer/networks/line_lstm_ctc.py')
-rw-r--r--src/text_recognizer/networks/line_lstm_ctc.py63
1 files changed, 49 insertions, 14 deletions
diff --git a/src/text_recognizer/networks/line_lstm_ctc.py b/src/text_recognizer/networks/line_lstm_ctc.py
index 5c57479..9009f94 100644
--- a/src/text_recognizer/networks/line_lstm_ctc.py
+++ b/src/text_recognizer/networks/line_lstm_ctc.py
@@ -1,9 +1,11 @@
"""LSTM with CTC for handwritten text recognition within a line."""
import importlib
+from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
from einops import rearrange, reduce
from einops.layers.torch import Rearrange, Reduce
+from loguru import logger
import torch
from torch import nn
from torch import Tensor
@@ -14,40 +16,72 @@ class LineRecurrentNetwork(nn.Module):
def __init__(
self,
- encoder: str,
- encoder_args: Dict = None,
+ backbone: str,
+ backbone_args: Dict = None,
flatten: bool = True,
input_size: int = 128,
hidden_size: int = 128,
+ bidirectional: bool = False,
num_layers: int = 1,
num_classes: int = 80,
patch_size: Tuple[int, int] = (28, 28),
stride: Tuple[int, int] = (1, 14),
) -> None:
super().__init__()
- self.encoder_args = encoder_args or {}
+ self.backbone_args = backbone_args or {}
self.patch_size = patch_size
self.stride = stride
self.sliding_window = self._configure_sliding_window()
self.input_size = input_size
self.hidden_size = hidden_size
- self.encoder = self._configure_encoder(encoder)
+ self.backbone = self._configure_backbone(backbone)
+ self.bidirectional = bidirectional
self.flatten = flatten
- self.fc = nn.Linear(in_features=self.input_size, out_features=self.hidden_size)
+
+ if self.flatten:
+ self.fc = nn.Linear(
+ in_features=self.input_size, out_features=self.hidden_size
+ )
+
self.rnn = nn.LSTM(
input_size=self.hidden_size,
hidden_size=self.hidden_size,
+ bidirectional=bidirectional,
num_layers=num_layers,
)
+
+ decoder_size = self.hidden_size * 2 if self.bidirectional else self.hidden_size
+
self.decoder = nn.Sequential(
- nn.Linear(in_features=self.hidden_size, out_features=num_classes),
+ nn.Linear(in_features=decoder_size, out_features=num_classes),
nn.LogSoftmax(dim=2),
)
- def _configure_encoder(self, encoder: str) -> Type[nn.Module]:
+ def _configure_backbone(self, backbone: str) -> Type[nn.Module]:
network_module = importlib.import_module("text_recognizer.networks")
- encoder_ = getattr(network_module, encoder)
- return encoder_(**self.encoder_args)
+ backbone_ = getattr(network_module, backbone)
+
+ if "pretrained" in self.backbone_args:
+ logger.info("Loading pretrained backbone.")
+ checkpoint_file = Path(__file__).resolve().parents[
+ 2
+ ] / self.backbone_args.pop("pretrained")
+
+ # Loading state directory.
+ state_dict = torch.load(checkpoint_file)
+ network_args = state_dict["network_args"]
+ weights = state_dict["model_state"]
+
+ # Initializes the network with trained weights.
+ backbone = backbone_(**network_args)
+ backbone.load_state_dict(weights)
+ if "freeze" in self.backbone_args and self.backbone_args["freeze"] is True:
+ for params in backbone.parameters():
+ params.requires_grad = False
+
+ return backbone
+ else:
+ return backbone_(**self.backbone_args)
def _configure_sliding_window(self) -> nn.Sequential:
return nn.Sequential(
@@ -69,13 +103,14 @@ class LineRecurrentNetwork(nn.Module):
# Rearrange from a sequence of patches for feedforward network.
b, t = x.shape[:2]
x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
- x = self.encoder(x)
+ x = self.backbone(x)
# Avgerage pooling.
- x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t) if self.flatten else x
-
- # Linear layer between CNN and RNN
- x = self.fc(x)
+ x = (
+ self.fc(reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t))
+ if self.flatten
+ else rearrange(x, "(b t) h -> t b h", b=b, t=t)
+ )
# Sequence predictions.
x, _ = self.rnn(x)