| author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-09 23:24:02 +0200 |
|---|---|---|
| committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-09 23:24:02 +0200 |
| commit | 53677be4ec14854ea4881b0d78730e0414c8dedd (patch) | |
| tree | 56eaace5e9906c7d408b6a251ca100b5c8b4e991 /src/text_recognizer/networks/lenet.py | |
| parent | 125d5da5fb845d03bda91426e172bca7f537584a (diff) | |
Working bash scripts etc.
Diffstat (limited to 'src/text_recognizer/networks/lenet.py')
| -rw-r--r-- | src/text_recognizer/networks/lenet.py | 19 |
1 file changed, 5 insertions, 14 deletions
```diff
diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py
index 2839a0c..cbc58fc 100644
--- a/src/text_recognizer/networks/lenet.py
+++ b/src/text_recognizer/networks/lenet.py
@@ -1,24 +1,16 @@
 """Defines the LeNet network."""
 from typing import Callable, Dict, Optional, Tuple
 
+from einops.layers.torch import Rearrange
 import torch
 from torch import nn
 
 
-class Flatten(nn.Module):
-    """Flattens a tensor."""
-
-    def forward(self, x: int) -> torch.Tensor:
-        """Flattens a tensor for input to a nn.Linear layer."""
-        return torch.flatten(x, start_dim=1)
-
-
 class LeNet(nn.Module):
     """LeNet network."""
 
     def __init__(
         self,
-        input_size: Tuple[int, ...] = (1, 28, 28),
         channels: Tuple[int, ...] = (1, 32, 64),
         kernel_sizes: Tuple[int, ...] = (3, 3, 2),
         hidden_size: Tuple[int, ...] = (9216, 128),
@@ -30,7 +22,6 @@ class LeNet(nn.Module):
         """The LeNet network.
 
         Args:
-            input_size (Tuple[int, ...]): The input shape of the network. Defaults to (1, 28, 28).
             channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64).
             kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (3, 3, 2).
             hidden_size (Tuple[int, ...]): Size of the flattened output from the convolutional layers.
@@ -44,10 +35,9 @@ class LeNet(nn.Module):
 
         """
         super().__init__()
-        self._input_size = input_size
-
         if activation_fn is not None:
-            activation_fn = getattr(nn, activation_fn)(activation_fn_args)
+            activation_fn_args = activation_fn_args or {}
+            activation_fn = getattr(nn, activation_fn)(**activation_fn_args)
         else:
             activation_fn = nn.ReLU(inplace=True)
 
@@ -66,7 +56,7 @@ class LeNet(nn.Module):
             activation_fn,
             nn.MaxPool2d(kernel_sizes[2]),
             nn.Dropout(p=dropout_rate),
-            Flatten(),
+            Rearrange("b c h w -> b (c h w)"),
             nn.Linear(in_features=hidden_size[0], out_features=hidden_size[1]),
             activation_fn,
             nn.Dropout(p=dropout_rate),
@@ -77,6 +67,7 @@ class LeNet(nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """The feedforward."""
+        # If the batch dimension is missing, it needs to be added.
         if len(x.shape) == 3:
             x = x.unsqueeze(0)
         return self.layers(x)
```
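For readers skimming the diff, the two functional changes are the swap of the hand-rolled `Flatten` module for einops' `Rearrange` layer, and the activation lookup now unpacking its argument dict as keyword arguments instead of passing it positionally. A minimal sketch of both follows; the tensor shapes and the `"SELU"` activation name are illustrative assumptions, not taken from the commit:

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

# 1. The removed Flatten module reduced to torch.flatten(x, start_dim=1);
#    Rearrange("b c h w -> b (c h w)") performs the identical reshape.
x = torch.randn(8, 64, 12, 12)  # (batch, channels, height, width); shapes are illustrative
flatten_old = torch.flatten(x, start_dim=1)
flatten_new = Rearrange("b c h w -> b (c h w)")(x)
assert torch.equal(flatten_old, flatten_new)  # both yield shape (8, 9216)

# 2. The activation fix: the old code passed activation_fn_args positionally,
#    so a dict like {"inplace": True} would bind to the first parameter of the
#    activation's constructor instead of being expanded. The new code defaults
#    None to {} and unpacks the dict as keyword arguments.
activation_fn, activation_fn_args = "SELU", None  # "SELU" chosen only for illustration
activation_fn_args = activation_fn_args or {}
activation = getattr(nn, activation_fn)(**activation_fn_args)  # -> nn.SELU()
```

A side benefit of the einops form is that the pattern string documents the expected layout in place, which presumably motivated replacing the bespoke module.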