author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-07-22 23:18:08 +0200
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-07-22 23:18:08 +0200
commit | f473456c19558aaf8552df97a51d4e18cc69dfa8 (patch)
tree | 0d35ce2410ff623ba5fb433d616d95b67ecf7a98 /src/text_recognizer/networks/mlp.py
parent | ad3bd52530f4800d4fb05dfef3354921f95513af (diff)
Working training loop and testing of trained CharacterModel.
Diffstat (limited to 'src/text_recognizer/networks/mlp.py')
-rw-r--r-- | src/text_recognizer/networks/mlp.py | 71 |
1 file changed, 31 insertions, 40 deletions
```diff
diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py
index 2a41790..d704d99 100644
--- a/src/text_recognizer/networks/mlp.py
+++ b/src/text_recognizer/networks/mlp.py
@@ -1,5 +1,5 @@
 """Defines the MLP network."""
-from typing import Callable, Optional
+from typing import Callable, Dict, List, Optional, Union
 
 import torch
 from torch import nn
@@ -10,45 +10,54 @@ class MLP(nn.Module):
 
     def __init__(
         self,
-        input_size: int,
-        output_size: int,
-        hidden_size: int,
-        num_layers: int,
-        dropout_rate: float,
+        input_size: int = 784,
+        output_size: int = 10,
+        hidden_size: Union[int, List] = 128,
+        num_layers: int = 3,
+        dropout_rate: float = 0.2,
         activation_fn: Optional[Callable] = None,
+        activation_fn_args: Optional[Dict] = None,
     ) -> None:
         """Initialization of the MLP network.
 
         Args:
-            input_size (int): The input shape of the network.
-            output_size (int): Number of classes in the dataset.
-            hidden_size (int): The number of `neurons` in each hidden layer.
-            num_layers (int): The number of hidden layers.
-            dropout_rate (float): The dropout rate at each layer.
-            activation_fn (Optional[Callable]): The activation function in the hidden layers, (default:
-                nn.ReLU()).
+            input_size (int): The input shape of the network. Defaults to 784.
+            output_size (int): Number of classes in the dataset. Defaults to 10.
+            hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128.
+            num_layers (int): The number of hidden layers. Defaults to 3.
+            dropout_rate (float): The dropout rate at each layer. Defaults to 0.2.
+            activation_fn (Optional[Callable]): The activation function in the hidden layers. Defaults to
+                None.
+            activation_fn_args (Optional[Dict]): The arguments for the activation function. Defaults to None.
 
         """
         super().__init__()
 
-        if activation_fn is None:
+        if activation_fn is not None:
+            activation_fn = getattr(nn, activation_fn)(activation_fn_args)
+        else:
             activation_fn = nn.ReLU(inplace=True)
 
+        if isinstance(hidden_size, int):
+            hidden_size = [hidden_size] * num_layers
+
         self.layers = [
-            nn.Linear(in_features=input_size, out_features=hidden_size),
+            nn.Linear(in_features=input_size, out_features=hidden_size[0]),
             activation_fn,
         ]
 
-        for _ in range(num_layers):
+        for i in range(num_layers - 1):
             self.layers += [
-                nn.Linear(in_features=hidden_size, out_features=hidden_size),
+                nn.Linear(in_features=hidden_size[i], out_features=hidden_size[i + 1]),
                 activation_fn,
             ]
 
             if dropout_rate:
                 self.layers.append(nn.Dropout(p=dropout_rate))
 
-        self.layers.append(nn.Linear(in_features=hidden_size, out_features=output_size))
+        self.layers.append(
+            nn.Linear(in_features=hidden_size[-1], out_features=output_size)
+        )
 
         self.layers = nn.Sequential(*self.layers)
 
@@ -57,25 +66,7 @@ class MLP(nn.Module):
         x = torch.flatten(x, start_dim=1)
         return self.layers(x)
 
-
-# def test():
-#     x = torch.randn([1, 28, 28])
-#     input_size = torch.flatten(x).shape[0]
-#     output_size = 10
-#     hidden_size = 128
-#     num_layers = 5
-#     dropout_rate = 0.25
-#     activation_fn = nn.GELU()
-#     net = MLP(
-#         input_size=input_size,
-#         output_size=output_size,
-#         hidden_size=hidden_size,
-#         num_layers=num_layers,
-#         dropout_rate=dropout_rate,
-#         activation_fn=activation_fn,
-#     )
-#     from torchsummary import summary
-#
-#     summary(net, (1, 28, 28), device="cpu")
-#
-#     out = net(x)
+    @property
+    def __name__(self) -> str:
+        """Returns the name of the network."""
+        return "mlp"
```
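The diff changes `hidden_size` to accept either an int (repeated `num_layers` times) or an explicit per-layer list, and adds an optional `activation_fn_args` dict. Below is a minimal usage sketch, not part of the commit, assuming the package is importable as `text_recognizer.networks.mlp` (matching the file path above):

```python
import torch

from text_recognizer.networks.mlp import MLP

# hidden_size may now be a single int (repeated num_layers times) or an
# explicit per-layer list; leaving activation_fn as None falls back to ReLU.
net = MLP(
    input_size=784,
    output_size=10,
    hidden_size=[256, 128, 64],
    num_layers=3,
    dropout_rate=0.2,
)

x = torch.randn(8, 28, 28)          # dummy batch of 28x28 images
logits = net(x)                     # forward() flattens from dim 1: (8, 784) -> (8, 10)
print(net.__name__, logits.shape)   # mlp torch.Size([8, 10])
```

As written in the diff, a non-None `activation_fn` is treated as the name of a module in `torch.nn` (e.g. `"SELU"`), resolved with `getattr(nn, activation_fn)` and constructed with `activation_fn_args` passed as a single positional argument.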