From 7c4de6d88664d2ea1b084f316a11896dde3e1150 Mon Sep 17 00:00:00 2001
From: aktersnurra
Date: Tue, 23 Jun 2020 22:39:54 +0200
Subject: latest

---
 src/text_recognizer/networks/mlp.py | 81 +++++++++++++++++++++++++++++++++++++
 1 file changed, 81 insertions(+)
 create mode 100644 src/text_recognizer/networks/mlp.py

diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py
new file mode 100644
index 0000000..2a41790
--- /dev/null
+++ b/src/text_recognizer/networks/mlp.py
@@ -0,0 +1,81 @@
+"""Defines the MLP network."""
+from typing import Callable, Optional
+
+import torch
+from torch import nn
+
+
+class MLP(nn.Module):
+    """Multi layered perceptron network."""
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        hidden_size: int,
+        num_layers: int,
+        dropout_rate: float,
+        activation_fn: Optional[Callable] = None,
+    ) -> None:
+        """Initialization of the MLP network.
+
+        Args:
+            input_size (int): The number of input features to the network.
+            output_size (int): Number of classes in the dataset.
+            hidden_size (int): The number of neurons in each hidden layer.
+            num_layers (int): The number of hidden layers added after the first hidden layer.
+            dropout_rate (float): The dropout rate applied before the output layer.
+            activation_fn (Optional[Callable]): The activation function in the hidden layers (default:
+                nn.ReLU()).
+
+        """
+        super().__init__()
+
+        if activation_fn is None:
+            activation_fn = nn.ReLU(inplace=True)
+
+        self.layers = [
+            nn.Linear(in_features=input_size, out_features=hidden_size),
+            activation_fn,
+        ]
+
+        for _ in range(num_layers):
+            self.layers += [
+                nn.Linear(in_features=hidden_size, out_features=hidden_size),
+                activation_fn,
+            ]
+
+        if dropout_rate:
+            self.layers.append(nn.Dropout(p=dropout_rate))
+
+        self.layers.append(nn.Linear(in_features=hidden_size, out_features=output_size))
+
+        self.layers = nn.Sequential(*self.layers)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """The feedforward pass."""
+        x = torch.flatten(x, start_dim=1)
+        return self.layers(x)
+
+
+# def test():
+#     x = torch.randn([1, 28, 28])
+#     input_size = torch.flatten(x).shape[0]
+#     output_size = 10
+#     hidden_size = 128
+#     num_layers = 5
+#     dropout_rate = 0.25
+#     activation_fn = nn.GELU()
+#     net = MLP(
+#         input_size=input_size,
+#         output_size=output_size,
+#         hidden_size=hidden_size,
+#         num_layers=num_layers,
+#         dropout_rate=dropout_rate,
+#         activation_fn=activation_fn,
+#     )
+#     from torchsummary import summary
+#
+#     summary(net, (1, 28, 28), device="cpu")
+#
+#     out = net(x)
--
cgit v1.2.3-70-g09d2
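
For reference, a minimal usage sketch of the MLP added in this patch. The import path and the hyperparameter values below are illustrative assumptions, not part of the commit.

# Illustrative usage only; assumes the package is importable as `text_recognizer`.
import torch

from text_recognizer.networks.mlp import MLP

# 28x28 grayscale images (e.g. EMNIST) flattened to 784 input features.
net = MLP(
    input_size=784,
    output_size=10,
    hidden_size=128,
    num_layers=3,
    dropout_rate=0.25,
)

x = torch.randn(8, 28, 28)   # a batch of 8 images
logits = net(x)              # forward() flattens to (8, 784) internally
print(logits.shape)          # torch.Size([8, 10])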