summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/crnn.py
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-11-08 12:41:04 +0100
committeraktersnurra <gustaf.rydholm@gmail.com>2020-11-08 12:41:04 +0100
commitbeeaef529e7c893a3475fe27edc880e283373725 (patch)
tree59eb72562bf7a5a9470c2586e6280600ad94f1ae /src/text_recognizer/networks/crnn.py
parent4d7713746eb936832e84852e90292936b933e87d (diff)
Trying to get the CNNTransformer to work, but it is hard.
Diffstat (limited to 'src/text_recognizer/networks/crnn.py')
-rw-r--r--src/text_recognizer/networks/crnn.py40
1 files changed, 27 insertions, 13 deletions
diff --git a/src/text_recognizer/networks/crnn.py b/src/text_recognizer/networks/crnn.py
index 3e605e2..9747429 100644
--- a/src/text_recognizer/networks/crnn.py
+++ b/src/text_recognizer/networks/crnn.py
@@ -1,12 +1,9 @@
"""LSTM with CTC for handwritten text recognition within a line."""
-import importlib
-from pathlib import Path
-from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+from typing import Dict, Tuple
from einops import rearrange, reduce
-from einops.layers.torch import Rearrange, Reduce
+from einops.layers.torch import Rearrange
from loguru import logger
-import torch
from torch import nn
from torch import Tensor
@@ -28,16 +25,21 @@ class ConvolutionalRecurrentNetwork(nn.Module):
patch_size: Tuple[int, int] = (28, 28),
stride: Tuple[int, int] = (1, 14),
recurrent_cell: str = "lstm",
+ avg_pool: bool = False,
+ use_sliding_window: bool = True,
) -> None:
super().__init__()
self.backbone_args = backbone_args or {}
self.patch_size = patch_size
self.stride = stride
- self.sliding_window = self._configure_sliding_window()
+ self.sliding_window = (
+ self._configure_sliding_window() if use_sliding_window else None
+ )
self.input_size = input_size
self.hidden_size = hidden_size
self.backbone = configure_backbone(backbone, backbone_args)
self.bidirectional = bidirectional
+ self.avg_pool = avg_pool
if recurrent_cell.upper() in ["LSTM", "GRU"]:
recurrent_cell = getattr(nn, recurrent_cell)
@@ -76,15 +78,27 @@ class ConvolutionalRecurrentNetwork(nn.Module):
"""Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM."""
if len(x.shape) < 4:
x = x[(None,) * (4 - len(x.shape))]
- x = self.sliding_window(x)
- # Rearrange from a sequence of patches for feedforward network.
- b, t = x.shape[:2]
- x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
- x = self.backbone(x)
+ if self.sliding_window is not None:
+ # Create image patches with a sliding window kernel.
+ x = self.sliding_window(x)
+
+ # Rearrange from a sequence of patches for feedforward network.
+ b, t = x.shape[:2]
+ x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
- # Avgerage pooling.
- x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)
+ x = self.backbone(x)
+
+ # Avgerage pooling.
+ if self.avg_pool:
+ x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)
+ else:
+ x = rearrange(x, "(b t) h -> t b h", b=b, t=t)
+ else:
+ # Encode the entire image with a CNN, and use the channels as temporal dimension.
+ b = x.shape[0]
+ x = self.backbone(x)
+ x = rearrange(x, "b c h w -> c b (h w)", b=b)
# Sequence predictions.
x, _ = self.rnn(x)