summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/lenet.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/networks/lenet.py')
-rw-r--r--src/text_recognizer/networks/lenet.py19
1 files changed, 5 insertions, 14 deletions
diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py
index 2839a0c..cbc58fc 100644
--- a/src/text_recognizer/networks/lenet.py
+++ b/src/text_recognizer/networks/lenet.py
@@ -1,24 +1,16 @@
"""Defines the LeNet network."""
from typing import Callable, Dict, Optional, Tuple
+from einops.layers.torch import Rearrange
import torch
from torch import nn
-class Flatten(nn.Module):
- """Flattens a tensor."""
-
- def forward(self, x: int) -> torch.Tensor:
- """Flattens a tensor for input to a nn.Linear layer."""
- return torch.flatten(x, start_dim=1)
-
-
class LeNet(nn.Module):
"""LeNet network."""
def __init__(
self,
- input_size: Tuple[int, ...] = (1, 28, 28),
channels: Tuple[int, ...] = (1, 32, 64),
kernel_sizes: Tuple[int, ...] = (3, 3, 2),
hidden_size: Tuple[int, ...] = (9216, 128),
@@ -30,7 +22,6 @@ class LeNet(nn.Module):
"""The LeNet network.
Args:
- input_size (Tuple[int, ...]): The input shape of the network. Defaults to (1, 28, 28).
channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64).
kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (3, 3, 2).
hidden_size (Tuple[int, ...]): Size of the flattend output form the convolutional layers.
@@ -44,10 +35,9 @@ class LeNet(nn.Module):
"""
super().__init__()
- self._input_size = input_size
-
if activation_fn is not None:
- activation_fn = getattr(nn, activation_fn)(activation_fn_args)
+ activation_fn_args = activation_fn_args or {}
+ activation_fn = getattr(nn, activation_fn)(**activation_fn_args)
else:
activation_fn = nn.ReLU(inplace=True)
@@ -66,7 +56,7 @@ class LeNet(nn.Module):
activation_fn,
nn.MaxPool2d(kernel_sizes[2]),
nn.Dropout(p=dropout_rate),
- Flatten(),
+ Rearrange("b c h w -> b (c h w)"),
nn.Linear(in_features=hidden_size[0], out_features=hidden_size[1]),
activation_fn,
nn.Dropout(p=dropout_rate),
@@ -77,6 +67,7 @@ class LeNet(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""The feedforward."""
+ # If batch dimenstion is missing, it needs to be added.
if len(x.shape) == 3:
x = x.unsqueeze(0)
return self.layers(x)