summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-08-09 23:24:02 +0200
committeraktersnurra <gustaf.rydholm@gmail.com>2020-08-09 23:24:02 +0200
commit53677be4ec14854ea4881b0d78730e0414c8dedd (patch)
tree56eaace5e9906c7d408b6a251ca100b5c8b4e991 /src/text_recognizer/networks
parent125d5da5fb845d03bda91426e172bca7f537584a (diff)
Working bash scripts etc.
Diffstat (limited to 'src/text_recognizer/networks')
-rw-r--r--src/text_recognizer/networks/ctc.py10
-rw-r--r--src/text_recognizer/networks/lenet.py19
-rw-r--r--src/text_recognizer/networks/line_lstm_ctc.py4
-rw-r--r--src/text_recognizer/networks/misc.py28
-rw-r--r--src/text_recognizer/networks/mlp.py9
-rw-r--r--src/text_recognizer/networks/residual_network.py1
6 files changed, 55 insertions, 16 deletions
diff --git a/src/text_recognizer/networks/ctc.py b/src/text_recognizer/networks/ctc.py
new file mode 100644
index 0000000..00ad47e
--- /dev/null
+++ b/src/text_recognizer/networks/ctc.py
@@ -0,0 +1,10 @@
+"""Decodes the CTC output."""
+#
+# from typing import Tuple
+# import torch
+#
+#
+# def greedy_decoder(
+# output, labels, label_length, blank_label, collapse_repeated=True
+# ) -> Tuple[torch.Tensor, torch.Tensor]:
+# pass
diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py
index 2839a0c..cbc58fc 100644
--- a/src/text_recognizer/networks/lenet.py
+++ b/src/text_recognizer/networks/lenet.py
@@ -1,24 +1,16 @@
"""Defines the LeNet network."""
from typing import Callable, Dict, Optional, Tuple
+from einops.layers.torch import Rearrange
import torch
from torch import nn
-class Flatten(nn.Module):
- """Flattens a tensor."""
-
- def forward(self, x: int) -> torch.Tensor:
- """Flattens a tensor for input to a nn.Linear layer."""
- return torch.flatten(x, start_dim=1)
-
-
class LeNet(nn.Module):
"""LeNet network."""
def __init__(
self,
- input_size: Tuple[int, ...] = (1, 28, 28),
channels: Tuple[int, ...] = (1, 32, 64),
kernel_sizes: Tuple[int, ...] = (3, 3, 2),
hidden_size: Tuple[int, ...] = (9216, 128),
@@ -30,7 +22,6 @@ class LeNet(nn.Module):
"""The LeNet network.
Args:
- input_size (Tuple[int, ...]): The input shape of the network. Defaults to (1, 28, 28).
channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64).
kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (3, 3, 2).
hidden_size (Tuple[int, ...]): Size of the flattend output form the convolutional layers.
@@ -44,10 +35,9 @@ class LeNet(nn.Module):
"""
super().__init__()
- self._input_size = input_size
-
if activation_fn is not None:
- activation_fn = getattr(nn, activation_fn)(activation_fn_args)
+ activation_fn_args = activation_fn_args or {}
+ activation_fn = getattr(nn, activation_fn)(**activation_fn_args)
else:
activation_fn = nn.ReLU(inplace=True)
@@ -66,7 +56,7 @@ class LeNet(nn.Module):
activation_fn,
nn.MaxPool2d(kernel_sizes[2]),
nn.Dropout(p=dropout_rate),
- Flatten(),
+ Rearrange("b c h w -> b (c h w)"),
nn.Linear(in_features=hidden_size[0], out_features=hidden_size[1]),
activation_fn,
nn.Dropout(p=dropout_rate),
@@ -77,6 +67,7 @@ class LeNet(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""The feedforward."""
+ # If batch dimenstion is missing, it needs to be added.
if len(x.shape) == 3:
x = x.unsqueeze(0)
return self.layers(x)
diff --git a/src/text_recognizer/networks/line_lstm_ctc.py b/src/text_recognizer/networks/line_lstm_ctc.py
new file mode 100644
index 0000000..d704139
--- /dev/null
+++ b/src/text_recognizer/networks/line_lstm_ctc.py
@@ -0,0 +1,4 @@
+"""LSTM with CTC for handwritten text recognition within a line."""
+
+import torch
+from torch import nn
diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/misc.py
new file mode 100644
index 0000000..9440f9d
--- /dev/null
+++ b/src/text_recognizer/networks/misc.py
@@ -0,0 +1,28 @@
+"""Miscellaneous neural network functionality."""
+from typing import Tuple
+
+from einops import rearrange
+import torch
+from torch.nn import Unfold
+
+
+def sliding_window(
+ images: torch.Tensor, patch_size: Tuple[int, int], stride: Tuple[int, int]
+) -> torch.Tensor:
+ """Creates patches of an image.
+
+ Args:
+ images (torch.Tensor): A Torch tensor of a 4D image(s), i.e. (batch, channel, height, width).
+ patch_size (Tuple[int, int]): The size of the patches to generate, e.g. 28x28 for EMNIST.
+ stride (Tuple[int, int]): The stride of the sliding window.
+
+ Returns:
+ torch.Tensor: A tensor with the shape (batch, patches, height, width).
+
+ """
+ unfold = Unfold(kernel_size=patch_size, stride=stride)
+ patches = unfold(images)
+ patches = rearrange(
+ patches, "b (h w) c -> b c h w", h=patch_size[0], w=patch_size[1]
+ )
+ return patches
diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py
index d704d99..ac2c825 100644
--- a/src/text_recognizer/networks/mlp.py
+++ b/src/text_recognizer/networks/mlp.py
@@ -1,6 +1,7 @@
"""Defines the MLP network."""
from typing import Callable, Dict, List, Optional, Union
+from einops.layers.torch import Rearrange
import torch
from torch import nn
@@ -34,7 +35,8 @@ class MLP(nn.Module):
super().__init__()
if activation_fn is not None:
- activation_fn = getattr(nn, activation_fn)(activation_fn_args)
+ activation_fn_args = activation_fn_args or {}
+ activation_fn = getattr(nn, activation_fn)(**activation_fn_args)
else:
activation_fn = nn.ReLU(inplace=True)
@@ -42,6 +44,7 @@ class MLP(nn.Module):
hidden_size = [hidden_size] * num_layers
self.layers = [
+ Rearrange("b c h w -> b (c h w)"),
nn.Linear(in_features=input_size, out_features=hidden_size[0]),
activation_fn,
]
@@ -63,7 +66,9 @@ class MLP(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""The feedforward."""
- x = torch.flatten(x, start_dim=1)
+ # If batch dimenstion is missing, it needs to be added.
+ if len(x.shape) == 3:
+ x = x.unsqueeze(0)
return self.layers(x)
@property
diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py
new file mode 100644
index 0000000..23394b0
--- /dev/null
+++ b/src/text_recognizer/networks/residual_network.py
@@ -0,0 +1 @@
+"""Residual CNN."""