From 1f459ba19422593de325983040e176f97cf4ffc0 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Thu, 20 Aug 2020 22:18:35 +0200 Subject: A lot of stuff working :D. ResNet implemented! --- src/text_recognizer/networks/mlp.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) (limited to 'src/text_recognizer/networks/mlp.py') diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py index ac2c825..acebdaa 100644 --- a/src/text_recognizer/networks/mlp.py +++ b/src/text_recognizer/networks/mlp.py @@ -5,6 +5,8 @@ from einops.layers.torch import Rearrange import torch from torch import nn +from text_recognizer.networks.misc import activation_function + class MLP(nn.Module): """Multi layered perceptron network.""" @@ -16,8 +18,7 @@ class MLP(nn.Module): hidden_size: Union[int, List] = 128, num_layers: int = 3, dropout_rate: float = 0.2, - activation_fn: Optional[Callable] = None, - activation_fn_args: Optional[Dict] = None, + activation_fn: str = "relu", ) -> None: """Initialization of the MLP network. @@ -27,18 +28,13 @@ class MLP(nn.Module): hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128. num_layers (int): The number of hidden layers. Defaults to 3. dropout_rate (float): The dropout rate at each layer. Defaults to 0.2. - activation_fn (Optional[Callable]): The activation function in the hidden layers. Defaults to - None. - activation_fn_args (Optional[Dict]): The arguments for the activation function. Defaults to None. + activation_fn (str): Name of the activation function in the hidden layers. Defaults to + relu. """ super().__init__() - if activation_fn is not None: - activation_fn_args = activation_fn_args or {} - activation_fn = getattr(nn, activation_fn)(**activation_fn_args) - else: - activation_fn = nn.ReLU(inplace=True) + activation_fn = activation_function(activation_fn) if isinstance(hidden_size, int): hidden_size = [hidden_size] * num_layers @@ -65,7 +61,7 @@ class MLP(nn.Module): self.layers = nn.Sequential(*self.layers) def forward(self, x: torch.Tensor) -> torch.Tensor: - """The feedforward.""" + """The feedforward pass.""" # If batch dimenstion is missing, it needs to be added. if len(x.shape) == 3: x = x.unsqueeze(0) -- cgit v1.2.3-70-g09d2