diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-07-22 23:18:08 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-07-22 23:18:08 +0200 |
commit | f473456c19558aaf8552df97a51d4e18cc69dfa8 (patch) | |
tree | 0d35ce2410ff623ba5fb433d616d95b67ecf7a98 /src/text_recognizer/networks/lenet.py | |
parent | ad3bd52530f4800d4fb05dfef3354921f95513af (diff) |
Working training loop and testing of trained CharacterModel.
Diffstat (limited to 'src/text_recognizer/networks/lenet.py')
-rw-r--r-- | src/text_recognizer/networks/lenet.py | 55 |
1 files changed, 22 insertions, 33 deletions
diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py index 71d247f..2839a0c 100644 --- a/src/text_recognizer/networks/lenet.py +++ b/src/text_recognizer/networks/lenet.py @@ -1,5 +1,5 @@ """Defines the LeNet network.""" -from typing import Callable, Optional, Tuple +from typing import Callable, Dict, Optional, Tuple import torch from torch import nn @@ -18,28 +18,37 @@ class LeNet(nn.Module): def __init__( self, - channels: Tuple[int, ...], - kernel_sizes: Tuple[int, ...], - hidden_size: Tuple[int, ...], - dropout_rate: float, - output_size: int, + input_size: Tuple[int, ...] = (1, 28, 28), + channels: Tuple[int, ...] = (1, 32, 64), + kernel_sizes: Tuple[int, ...] = (3, 3, 2), + hidden_size: Tuple[int, ...] = (9216, 128), + dropout_rate: float = 0.2, + output_size: int = 10, activation_fn: Optional[Callable] = None, + activation_fn_args: Optional[Dict] = None, ) -> None: """The LeNet network. Args: - channels (Tuple[int, ...]): Channels in the convolutional layers. - kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. + input_size (Tuple[int, ...]): The input shape of the network. Defaults to (1, 28, 28). + channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64). + kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (3, 3, 2). hidden_size (Tuple[int, ...]): Size of the flattend output form the convolutional layers. - dropout_rate (float): The dropout rate. - output_size (int): Number of classes. + Defaults to (9216, 128). + dropout_rate (float): The dropout rate. Defaults to 0.2. + output_size (int): Number of classes. Defaults to 10. activation_fn (Optional[Callable]): The non-linear activation function. Defaults to nn.ReLU(inplace). + activation_fn_args (Optional[Dict]): The arguments for the activation function. Defaults to None. """ super().__init__() - if activation_fn is None: + self._input_size = input_size + + if activation_fn is not None: + activation_fn = getattr(nn, activation_fn)(activation_fn_args) + else: activation_fn = nn.ReLU(inplace=True) self.layers = [ @@ -68,26 +77,6 @@ class LeNet(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: """The feedforward.""" + if len(x.shape) == 3: + x = x.unsqueeze(0) return self.layers(x) - - -# def test(): -# x = torch.randn([1, 1, 28, 28]) -# channels = [1, 32, 64] -# kernel_sizes = [3, 3, 2] -# hidden_size = [9216, 128] -# output_size = 10 -# dropout_rate = 0.2 -# activation_fn = nn.ReLU() -# net = LeNet( -# channels=channels, -# kernel_sizes=kernel_sizes, -# dropout_rate=dropout_rate, -# hidden_size=hidden_size, -# output_size=output_size, -# activation_fn=activation_fn, -# ) -# from torchsummary import summary -# -# summary(net, (1, 28, 28), device="cpu") -# out = net(x) |