diff options
Diffstat (limited to 'src/text_recognizer/networks/crnn.py')
-rw-r--r-- | src/text_recognizer/networks/crnn.py | 40 |
1 files changed, 27 insertions, 13 deletions
diff --git a/src/text_recognizer/networks/crnn.py b/src/text_recognizer/networks/crnn.py index 3e605e2..9747429 100644 --- a/src/text_recognizer/networks/crnn.py +++ b/src/text_recognizer/networks/crnn.py @@ -1,12 +1,9 @@ """LSTM with CTC for handwritten text recognition within a line.""" -import importlib -from pathlib import Path -from typing import Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Dict, Tuple from einops import rearrange, reduce -from einops.layers.torch import Rearrange, Reduce +from einops.layers.torch import Rearrange from loguru import logger -import torch from torch import nn from torch import Tensor @@ -28,16 +25,21 @@ class ConvolutionalRecurrentNetwork(nn.Module): patch_size: Tuple[int, int] = (28, 28), stride: Tuple[int, int] = (1, 14), recurrent_cell: str = "lstm", + avg_pool: bool = False, + use_sliding_window: bool = True, ) -> None: super().__init__() self.backbone_args = backbone_args or {} self.patch_size = patch_size self.stride = stride - self.sliding_window = self._configure_sliding_window() + self.sliding_window = ( + self._configure_sliding_window() if use_sliding_window else None + ) self.input_size = input_size self.hidden_size = hidden_size self.backbone = configure_backbone(backbone, backbone_args) self.bidirectional = bidirectional + self.avg_pool = avg_pool if recurrent_cell.upper() in ["LSTM", "GRU"]: recurrent_cell = getattr(nn, recurrent_cell) @@ -76,15 +78,27 @@ class ConvolutionalRecurrentNetwork(nn.Module): """Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM.""" if len(x.shape) < 4: x = x[(None,) * (4 - len(x.shape))] - x = self.sliding_window(x) - # Rearrange from a sequence of patches for feedforward network. - b, t = x.shape[:2] - x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t) - x = self.backbone(x) + if self.sliding_window is not None: + # Create image patches with a sliding window kernel. + x = self.sliding_window(x) + + # Rearrange from a sequence of patches for feedforward network. + b, t = x.shape[:2] + x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t) - # Avgerage pooling. - x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t) + x = self.backbone(x) + + # Avgerage pooling. + if self.avg_pool: + x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t) + else: + x = rearrange(x, "(b t) h -> t b h", b=b, t=t) + else: + # Encode the entire image with a CNN, and use the channels as temporal dimension. + b = x.shape[0] + x = self.backbone(x) + x = rearrange(x, "b c h w -> c b (h w)", b=b) # Sequence predictions. x, _ = self.rnn(x) |