diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-07-22 23:18:08 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-07-22 23:18:08 +0200 |
commit | f473456c19558aaf8552df97a51d4e18cc69dfa8 (patch) | |
tree | 0d35ce2410ff623ba5fb433d616d95b67ecf7a98 /src/text_recognizer/networks | |
parent | ad3bd52530f4800d4fb05dfef3354921f95513af (diff) |
Working training loop and testing of trained CharacterModel.
Diffstat (limited to 'src/text_recognizer/networks')
-rw-r--r-- | src/text_recognizer/networks/__init__.py | 4 | ||||
-rw-r--r-- | src/text_recognizer/networks/lenet.py | 55 | ||||
-rw-r--r-- | src/text_recognizer/networks/mlp.py | 71 |
3 files changed, 57 insertions, 73 deletions
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py index 4ea5bb3..e6b6946 100644 --- a/src/text_recognizer/networks/__init__.py +++ b/src/text_recognizer/networks/__init__.py @@ -1 +1,5 @@ """Network modules.""" +from .lenet import LeNet +from .mlp import MLP + +__all__ = ["MLP", "LeNet"] diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py index 71d247f..2839a0c 100644 --- a/src/text_recognizer/networks/lenet.py +++ b/src/text_recognizer/networks/lenet.py @@ -1,5 +1,5 @@ """Defines the LeNet network.""" -from typing import Callable, Optional, Tuple +from typing import Callable, Dict, Optional, Tuple import torch from torch import nn @@ -18,28 +18,37 @@ class LeNet(nn.Module): def __init__( self, - channels: Tuple[int, ...], - kernel_sizes: Tuple[int, ...], - hidden_size: Tuple[int, ...], - dropout_rate: float, - output_size: int, + input_size: Tuple[int, ...] = (1, 28, 28), + channels: Tuple[int, ...] = (1, 32, 64), + kernel_sizes: Tuple[int, ...] = (3, 3, 2), + hidden_size: Tuple[int, ...] = (9216, 128), + dropout_rate: float = 0.2, + output_size: int = 10, activation_fn: Optional[Callable] = None, + activation_fn_args: Optional[Dict] = None, ) -> None: """The LeNet network. Args: - channels (Tuple[int, ...]): Channels in the convolutional layers. - kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. + input_size (Tuple[int, ...]): The input shape of the network. Defaults to (1, 28, 28). + channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64). + kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (3, 3, 2). hidden_size (Tuple[int, ...]): Size of the flattend output form the convolutional layers. - dropout_rate (float): The dropout rate. - output_size (int): Number of classes. + Defaults to (9216, 128). + dropout_rate (float): The dropout rate. Defaults to 0.2. + output_size (int): Number of classes. Defaults to 10. activation_fn (Optional[Callable]): The non-linear activation function. Defaults to nn.ReLU(inplace). + activation_fn_args (Optional[Dict]): The arguments for the activation function. Defaults to None. """ super().__init__() - if activation_fn is None: + self._input_size = input_size + + if activation_fn is not None: + activation_fn = getattr(nn, activation_fn)(activation_fn_args) + else: activation_fn = nn.ReLU(inplace=True) self.layers = [ @@ -68,26 +77,6 @@ class LeNet(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: """The feedforward.""" + if len(x.shape) == 3: + x = x.unsqueeze(0) return self.layers(x) - - -# def test(): -# x = torch.randn([1, 1, 28, 28]) -# channels = [1, 32, 64] -# kernel_sizes = [3, 3, 2] -# hidden_size = [9216, 128] -# output_size = 10 -# dropout_rate = 0.2 -# activation_fn = nn.ReLU() -# net = LeNet( -# channels=channels, -# kernel_sizes=kernel_sizes, -# dropout_rate=dropout_rate, -# hidden_size=hidden_size, -# output_size=output_size, -# activation_fn=activation_fn, -# ) -# from torchsummary import summary -# -# summary(net, (1, 28, 28), device="cpu") -# out = net(x) diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py index 2a41790..d704d99 100644 --- a/src/text_recognizer/networks/mlp.py +++ b/src/text_recognizer/networks/mlp.py @@ -1,5 +1,5 @@ """Defines the MLP network.""" -from typing import Callable, Optional +from typing import Callable, Dict, List, Optional, Union import torch from torch import nn @@ -10,45 +10,54 @@ class MLP(nn.Module): def __init__( self, - input_size: int, - output_size: int, - hidden_size: int, - num_layers: int, - dropout_rate: float, + input_size: int = 784, + output_size: int = 10, + hidden_size: Union[int, List] = 128, + num_layers: int = 3, + dropout_rate: float = 0.2, activation_fn: Optional[Callable] = None, + activation_fn_args: Optional[Dict] = None, ) -> None: """Initialization of the MLP network. Args: - input_size (int): The input shape of the network. - output_size (int): Number of classes in the dataset. - hidden_size (int): The number of `neurons` in each hidden layer. - num_layers (int): The number of hidden layers. - dropout_rate (float): The dropout rate at each layer. - activation_fn (Optional[Callable]): The activation function in the hidden layers, (default: - nn.ReLU()). + input_size (int): The input shape of the network. Defaults to 784. + output_size (int): Number of classes in the dataset. Defaults to 10. + hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128. + num_layers (int): The number of hidden layers. Defaults to 3. + dropout_rate (float): The dropout rate at each layer. Defaults to 0.2. + activation_fn (Optional[Callable]): The activation function in the hidden layers. Defaults to + None. + activation_fn_args (Optional[Dict]): The arguments for the activation function. Defaults to None. """ super().__init__() - if activation_fn is None: + if activation_fn is not None: + activation_fn = getattr(nn, activation_fn)(activation_fn_args) + else: activation_fn = nn.ReLU(inplace=True) + if isinstance(hidden_size, int): + hidden_size = [hidden_size] * num_layers + self.layers = [ - nn.Linear(in_features=input_size, out_features=hidden_size), + nn.Linear(in_features=input_size, out_features=hidden_size[0]), activation_fn, ] - for _ in range(num_layers): + for i in range(num_layers - 1): self.layers += [ - nn.Linear(in_features=hidden_size, out_features=hidden_size), + nn.Linear(in_features=hidden_size[i], out_features=hidden_size[i + 1]), activation_fn, ] if dropout_rate: self.layers.append(nn.Dropout(p=dropout_rate)) - self.layers.append(nn.Linear(in_features=hidden_size, out_features=output_size)) + self.layers.append( + nn.Linear(in_features=hidden_size[-1], out_features=output_size) + ) self.layers = nn.Sequential(*self.layers) @@ -57,25 +66,7 @@ class MLP(nn.Module): x = torch.flatten(x, start_dim=1) return self.layers(x) - -# def test(): -# x = torch.randn([1, 28, 28]) -# input_size = torch.flatten(x).shape[0] -# output_size = 10 -# hidden_size = 128 -# num_layers = 5 -# dropout_rate = 0.25 -# activation_fn = nn.GELU() -# net = MLP( -# input_size=input_size, -# output_size=output_size, -# hidden_size=hidden_size, -# num_layers=num_layers, -# dropout_rate=dropout_rate, -# activation_fn=activation_fn, -# ) -# from torchsummary import summary -# -# summary(net, (1, 28, 28), device="cpu") -# -# out = net(x) + @property + def __name__(self) -> str: + """Returns the name of the network.""" + return "mlp" |